diff --git a/.circleci/config.yml b/.circleci/config.yml index 90098519eee0f..aa4624c9ea9b3 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -132,10 +132,10 @@ jobs: - checkout - run: ./build_tools/circle/checkout_merge_commit.sh - restore_cache: - key: linux-arm64-{{ .Branch }} + key: linux-arm64-ccache-v1-{{ .Branch }} - run: ./build_tools/circle/build_test_arm.sh - save_cache: - key: linux-arm64-{{ .Branch }} + key: linux-arm64-ccache-v1-{{ .Branch }} paths: - ~/.cache/ccache - ~/.cache/pip diff --git a/doc/conftest.py b/doc/conftest.py index 10253efeabf98..73e1b244a3efa 100644 --- a/doc/conftest.py +++ b/doc/conftest.py @@ -136,6 +136,13 @@ def pytest_runtest_setup(item): setup_preprocessing() elif fname.endswith("statistical_inference/unsupervised_learning.rst"): setup_unsupervised_learning() + elif fname.endswith("metadata_routing.rst"): + # TODO: remove this once implemented + # Skip metarouting because it is not fully implemented yet + raise SkipTest( + "Skipping doctest for metadata_routing.rst because it " + "is not fully implemented yet" + ) rst_files_requiring_matplotlib = [ "modules/partial_dependence.rst", diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst new file mode 100644 index 0000000000000..de04ecf022415 --- /dev/null +++ b/doc/metadata_routing.rst @@ -0,0 +1,215 @@ + +.. _metadata_routing: + +.. TODO: update doc/conftest.py once document is updated and examples run. + +Metadata Routing +================ + +This guide demonstrates how metadata such as ``sample_weight`` can be routed +and passed along to estimators, scorers, and CV splitters through +meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass +metadata to a method such as ``fit`` or ``score``, the object accepting the +metadata must *request* it. For estimators and splitters this is done via +``set_*_request`` methods, e.g. ``set_fit_request(...)``, and for scorers this +is done via the ``set_score_request`` method.
For grouped splitters such as +``GroupKFold``, a ``groups`` parameter is requested by default. This is best +demonstrated by the following examples. + +If you are developing a scikit-learn compatible estimator or meta-estimator, +you can check our related developer guide: +:ref:`sphx_glr_auto_examples_plot_metadata_routing.py`. + +Usage Examples +************** +Here we present a few examples to show different common use-cases. The examples +in this section require the following imports and data:: + + >>> import numpy as np + >>> from sklearn.metrics import make_scorer, accuracy_score + >>> from sklearn.linear_model import LogisticRegressionCV + >>> from sklearn.linear_model import LogisticRegression + >>> from sklearn.model_selection import cross_validate + >>> from sklearn.model_selection import GridSearchCV + >>> from sklearn.model_selection import GroupKFold + >>> from sklearn.feature_selection import SelectKBest + >>> from sklearn.utils.metadata_routing import RequestType + >>> from sklearn.pipeline import make_pipeline + >>> n_samples, n_features = 100, 4 + >>> X = np.random.rand(n_samples, n_features) + >>> y = np.random.randint(0, 2, size=n_samples) + >>> my_groups = np.random.randint(0, 10, size=n_samples) + >>> my_weights = np.random.rand(n_samples) + >>> my_other_weights = np.random.rand(n_samples) + +Weighted scoring and fitting +---------------------------- + +Here ``GroupKFold`` requests ``groups`` by default. However, we need to +explicitly request weights for our scorer and for ``LogisticRegressionCV``. +Both of these *consumers* know how to use metadata called ``"sample_weight"``:: + + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).set_fit_request(sample_weight=True) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ...
props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Note that in this example, ``my_weights`` is passed to both the scorer and +:class:`~linear_model.LogisticRegressionCV`. + +Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed +(note the typo), ``cross_validate`` would raise an error, since +``sample_weigh`` was not requested by any of its children. + +Weighted scoring and unweighted fitting +--------------------------------------- + +All scikit-learn estimators require weights to be either explicitly requested +or not requested (i.e. ``UNREQUESTED``) when used in another router such as a +``Pipeline`` or a ``*GridSearchCV``. To perform an unweighted fit, we need to +configure :class:`~linear_model.LogisticRegressionCV` to not request sample +weights, so that :func:`~model_selection.cross_validate` does not pass the +weights along:: + + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).set_fit_request(sample_weight=RequestType.UNREQUESTED) + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Note the usage of ``RequestType`` which in this case is equivalent to +``False``; the type is explained further at the end of this document. + +If :class:`~linear_model.LogisticRegressionCV` does not call +``set_fit_request``, :func:`~model_selection.cross_validate` will raise an +error because weights are passed in but +:class:`~linear_model.LogisticRegressionCV` would not be explicitly configured +to recognize the weights.
+ +Unweighted feature selection +---------------------------- + +Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and +therefore `"sample_weight"` is not routed to it:: + + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( + ... sample_weight=True + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).set_fit_request(sample_weight=True) + >>> sel = SelectKBest(k=2) + >>> pipe = make_pipeline(sel, lr) + >>> cv_results = cross_validate( + ... pipe, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={"sample_weight": my_weights, "groups": my_groups}, + ... scoring=weighted_acc, + ... ) + +Advanced: Different scoring and fitting weights +----------------------------------------------- + +Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting the key +``sample_weight``, we can use aliases to pass different weights to different +consumers. In this example, we pass ``scoring_weight`` to the scorer, and +``fitting_weight`` to ``LogisticRegressionCV``:: + + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( + ... sample_weight="scoring_weight" + ... ) + >>> lr = LogisticRegressionCV( + ... cv=GroupKFold(), scoring=weighted_acc, + ... ).set_fit_request(sample_weight="fitting_weight") + >>> cv_results = cross_validate( + ... lr, + ... X, + ... y, + ... cv=GroupKFold(), + ... props={ + ... "scoring_weight": my_weights, + ... "fitting_weight": my_other_weights, + ... "groups": my_groups, + ... }, + ... scoring=weighted_acc, + ... ) + +API Interface +************* + +A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which +accepts and uses some metadata in at least one of its methods (``fit``, +``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``). +Meta-estimators which only forward the metadata to other objects (the child +estimator, scorers, or splitters) and don't use the metadata themselves are not +consumers. 
(Meta)Estimators which route metadata to other objects are +*routers*. A (meta)estimator can be a consumer and a router at the same time. +(Meta)Estimators and splitters expose a ``set_*_request`` method for each +method which accepts at least one metadata. For instance, if an estimator +supports ``sample_weight`` in ``fit`` and ``score``, it exposes +``estimator.set_fit_request(sample_weight=value)`` and +``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be: + +- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``. + This means if the metadata is provided, it will be used, otherwise no error + is raised. +- ``RequestType.UNREQUESTED`` or ``False``: method does not request a + ``sample_weight``. +- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if + ``sample_weight`` is passed. This is in almost all cases the default value + when an object is instantiated and ensures the user sets the metadata + requests explicitly when a metadata is passed. The only exceptions are + ``Group*Fold`` splitters. +- ``"param_name"``: if this estimator is used in a meta-estimator, the + meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this + estimator. This means the mapping between the metadata required by the + object, e.g. ``sample_weight`` and what is provided by the user, e.g. + ``my_weights`` is done at the router level, and not by the object, e.g. + estimator, itself. + +For the scorers, this is done the same way, using the ``set_score_request`` method. + +If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata +request for all objects which potentially can accept ``sample_weight`` should +be set by the user, otherwise an error is raised by the router object.
For +example, the following code raises an error, since it hasn't been explicitly +specified whether ``sample_weight`` should be passed to the estimator's scorer +or not:: + + >>> param_grid = {"C": [0.1, 1]} + >>> lr = LogisticRegression().set_fit_request(sample_weight=True) + >>> try: + ... GridSearchCV( + ... estimator=lr, param_grid=param_grid + ... ).fit(X, y, sample_weight=my_weights) + ... except ValueError as e: + ... print(e) + sample_weight is passed but is not explicitly set as requested or not for + LogisticRegression.score + +The issue can be fixed by explicitly setting the request value:: + + >>> lr = LogisticRegression().set_fit_request( + ... sample_weight=True + ... ).set_score_request(sample_weight=False) diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index c6838556d50ad..a9c61b7a75170 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -34,6 +34,7 @@ Base classes base.DensityMixin base.RegressorMixin base.TransformerMixin + base.MetaEstimatorMixin feature_selection.SelectorMixin Functions @@ -1652,6 +1653,12 @@ Plotting utils.validation.column_or_1d utils.validation.has_fit_parameter utils.all_estimators + utils.metadata_routing.RequestType + utils.metadata_routing.get_routing_for_object + utils.metadata_routing.MetadataRouter + utils.metadata_routing.MetadataRequest + utils.metadata_routing.MethodMapping + utils.metadata_routing.process_routing Utilities from joblib: diff --git a/doc/modules/model_evaluation.rst b/doc/modules/model_evaluation.rst index 34412576f80aa..6beadd545df6c 100644 --- a/doc/modules/model_evaluation.rst +++ b/doc/modules/model_evaluation.rst @@ -228,6 +228,12 @@ the following two rules: Again, by convention higher numbers are better, so if your scorer returns loss, that value should be negated. +- Advanced: If it requires extra metadata to be passed to it, it should expose + a ``get_metadata_routing`` method returning the requested metadata. 
The user + should be able to set the requested metadata via a ``set_score_request`` + method. Please see :ref:`User Guide <metadata_routing>` for more details. + + .. note:: **Using custom scorers in functions where n_jobs > 1** While defining the custom scoring function alongside the calling function diff --git a/doc/user_guide.rst b/doc/user_guide.rst index 7237938784046..2e92fec1c63da 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -27,6 +27,7 @@ User Guide visualizations.rst data_transforms.rst datasets.rst + metadata_routing.rst computing.rst model_persistence.rst common_pitfalls.rst diff --git a/examples/plot_metadata_routing.py b/examples/plot_metadata_routing.py new file mode 100644 index 0000000000000..91c3651699ee4 --- /dev/null +++ b/examples/plot_metadata_routing.py @@ -0,0 +1,617 @@ +""" +================ +Metadata Routing +================ + +.. currentmodule:: sklearn + +This document shows how you can use the :ref:`metadata routing mechanism +<metadata_routing>` in scikit-learn to route metadata through meta-estimators +to the estimators consuming them. To better understand the rest of the +document, we need to introduce two concepts: routers and consumers. A router is +an object, in most cases a meta-estimator, which forwards given data and +metadata to other objects and estimators. A consumer, on the other hand, is an +object which accepts and uses a certain given metadata. For instance, an +estimator taking into account ``sample_weight`` in its :term:`fit` method is a +consumer of ``sample_weight``. It is possible for an object to be both a router +and a consumer. For instance, a meta-estimator may take into account +``sample_weight`` in certain calculations, but it may also route it to the +underlying estimator. + +First a few imports and some random data for the rest of the script.
+""" +# %% + +import numpy as np +import warnings +from pprint import pprint +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import RegressorMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import TransformerMixin +from sklearn.base import clone +from sklearn.utils.metadata_routing import RequestType +from sklearn.utils.metadata_routing import get_routing_for_object +from sklearn.utils.metadata_routing import MetadataRouter +from sklearn.utils.metadata_routing import MethodMapping +from sklearn.utils.metadata_routing import process_routing +from sklearn.utils.validation import check_is_fitted +from sklearn.linear_model import LinearRegression + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + +# %% +# This dummy utility function prints, for each given metadata, whether a value +# was received or whether it is None. + + +def check_metadata(obj, **kwargs): + for key, value in kwargs.items(): + if value is not None: + print( + f"Received {key} of length = {len(value)} in {obj.__class__.__name__}." + ) + else: + print(f"{key} is None in {obj.__class__.__name__}.") + + +# %% +# A utility function to nicely print the routing information of an object +def print_routing(obj): + pprint(obj.get_metadata_routing()._serialize()) + + +# %% +# Estimators +# ---------- +# Here we demonstrate how an estimator can expose the required API to support +# metadata routing as a consumer. Imagine a simple classifier accepting +# ``sample_weight`` as a metadata on its ``fit`` and ``groups`` in its +# ``predict`` method: + + +class ExampleClassifier(ClassifierMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + check_metadata(self, sample_weight=sample_weight) + # all classifiers need to expose a classes_ attribute once they're fit.
+ self.classes_ = np.array([0, 1]) + return self + + def predict(self, X, groups=None): + check_metadata(self, groups=groups) + # return a constant value of 1, not a very smart classifier! + return np.ones(len(X)) + + +# %% +# The above estimator now has all it needs to consume metadata. This is +# accomplished by some magic done in :class:`~base.BaseEstimator`. There are +# now three methods exposed by the above class: ``set_fit_request``, +# ``set_predict_request``, and ``get_metadata_routing``. There is also a +# ``set_score_request`` for ``sample_weight`` which is present since +# :class:`~base.ClassifierMixin` implements a ``score`` method accepting +# ``sample_weight``. The same applies to regressors which inherit from +# :class:`~base.RegressorMixin`. +# +# By default, no metadata is requested, which we can see as: + +print_routing(ExampleClassifier()) + +# %% +# The above output means that ``sample_weight`` and ``groups`` are not +# requested, but if a router is given those metadata, it should raise an error, +# since the user has not explicitly set whether they are required or not. The +# same is true for ``sample_weight`` in the ``score`` method, which is +# inherited from :class:`~base.ClassifierMixin`. In order to explicitly set +# request values for those metadata, we can use these methods: + +est = ( + ExampleClassifier() + .set_fit_request(sample_weight=False) + .set_predict_request(groups=True) + .set_score_request(sample_weight=False) +) +print_routing(est) + +# %% +# As you can see, the metadata now have explicit request values: one is +# requested and one is not. Instead of ``True`` and ``False``, we could also +# use the :class:`~sklearn.utils.metadata_routing.RequestType` values. + +est = ( + ExampleClassifier() + .set_fit_request(sample_weight=RequestType.UNREQUESTED) + .set_predict_request(groups=RequestType.REQUESTED) + .set_score_request(sample_weight=RequestType.UNREQUESTED) +) +print_routing(est) + +# %% +# ..
note :: +# Please note that as long as the above estimator is not used in another +# meta-estimator, the user does not need to set any requests for the +# metadata and the set values are ignored, since a consumer does not +# validate or route given metadata. A simple usage of the above estimator +# would work as expected. + +est = ExampleClassifier() +est.fit(X, y, sample_weight=my_weights) +est.predict(X[:3, :], groups=my_groups) + +# %% +# Now let's have a meta-estimator, which doesn't do much other than routing the +# metadata. + + +class MetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def get_metadata_routing(self): + # This method defines the routing for this meta-estimator. + # In order to do so, a `MetadataRouter` instance is created, and the + # right routing is added to it. More explanations follow. + router = MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.estimator, method_mapping="one-to-one" + ) + return router + + def fit(self, X, y, **fit_params): + # meta-estimators are responsible for validating the given metadata. + # `get_routing_for_object` is a safe way to construct a + # `MetadataRouter` or a `MetadataRequest` from the given object. + request_router = get_routing_for_object(self) + request_router.validate_metadata(params=fit_params, method="fit") + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator. Here `caller` refers to the + # parent's method, i.e. `fit` in this example. + routed_params = request_router.route_params(params=fit_params, caller="fit") + + # the output has a key for each object's method which is used here, + # i.e. parent's `fit` method, containing the metadata which should be + # routed to them, based on the information provided in + # `get_metadata_routing`.
+ self.estimator_ = clone(self.estimator).fit(X, y, **routed_params.estimator.fit) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + # same as in `fit`, we validate the given metadata + request_router = get_routing_for_object(self) + request_router.validate_metadata(params=predict_params, method="predict") + # and then prepare the input to the underlying `predict` method. + routed_params = request_router.route_params( + params=predict_params, caller="predict" + ) + return self.estimator_.predict(X, **routed_params.estimator.predict) + + +# %% +# Let's break down different parts of the above code. +# +# First, the :func:`~utils.metadata_routing.get_routing_for_object` takes an +# estimator (``self``) and returns a +# :class:`~utils.metadata_routing.MetadataRouter` or a +# :class:`~utils.metadata_routing.MetadataRequest` based on the output of the +# estimator's ``get_metadata_routing`` method. +# +# Then in each method, we use the ``route_params`` method to construct a +# dictionary of the form ``{"object_name": {"method_name": {"metadata": +# value}}}`` to pass to the underlying estimator's method. The ``object_name`` +# (``estimator`` in the above ``routed_params.estimator.fit`` example) is the +# same as the one ``add``ed in the ``get_metadata_routing``. +# ``validate_metadata`` makes sure all given metadata are requested.
This is to +# avoid silent bugs, and this is how it will work: + +est = MetaClassifier(estimator=ExampleClassifier().set_fit_request(sample_weight=True)) +est.fit(X, y, sample_weight=my_weights) + +# %% +# Note that the above example checks that ``sample_weight`` is correctly passed +# to ``ExampleClassifier``, or else it would print that ``sample_weight`` is +# ``None``: + +est.fit(X, y) + +# %% +# If we pass an unknown metadata, it will be caught: +try: + est.fit(X, y, test=my_weights) +except TypeError as e: + print(e) + +# %% +# And if we pass something which is not explicitly requested: +try: + est.fit(X, y, sample_weight=my_weights).predict(X, groups=my_groups) +except ValueError as e: + print(e) + +# %% +# Also, if we explicitly say it's not requested, but pass it: +est = MetaClassifier( + estimator=ExampleClassifier() + .set_fit_request(sample_weight=True) + .set_predict_request(groups=False) +) +try: + est.fit(X, y, sample_weight=my_weights).predict(X[:3, :], groups=my_groups) +except TypeError as e: + print(e) + +# %% +# Another concept to introduce is aliased metadata. This is when an estimator +# requests a metadata with a different name than the default value. For +# instance, in a setting where there are two estimators in a pipeline, one +# could request ``sample_weight1`` and the other ``sample_weight2``. Note that +# this doesn't change what the estimator expects, it only tells the +# meta-estimator how to map provided metadata to what's required. 
Here's an +# example, where we pass ``aliased_sample_weight`` to the meta-estimator, but +# the meta-estimator understands that ``aliased_sample_weight`` is an alias for +# ``sample_weight``, and passes it as ``sample_weight`` to the underlying +# estimator: +est = MetaClassifier( + estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight") +) +est.fit(X, y, aliased_sample_weight=my_weights) + +# %% +# And passing ``sample_weight`` here will fail since it is requested with an +# alias and ``sample_weight`` with that name is not requested: +try: + est.fit(X, y, sample_weight=my_weights) +except TypeError as e: + print(e) + +# %% +# This leads us to the ``get_metadata_routing``. The way routing works in +# scikit-learn is that consumers request what they need, and routers pass that +# along. Additionally, a router exposes what it requires itself so that it can +# be used inside another router, e.g. a pipeline inside a grid search object. +# The output of the ``get_metadata_routing`` which is a dictionary +# representation of a :class:`~utils.metadata_routing.MetadataRouter`, includes +# the complete tree of requested metadata by all nested objects and their +# corresponding method routings, i.e. which method of a sub-estimator is used +# in which method of a meta-estimator: + +print_routing(est) + +# %% +# As you can see, the only metadata requested for method ``fit`` is +# ``"sample_weight"`` with ``"aliased_sample_weight"`` as the alias. The +# ``MetadataRouter`` class enables us to easily create the routing object which +# would create the output we need for our ``get_metadata_routing``. In the +# above implementation, ``method_mapping="one-to-one"`` means there is a one to +# one mapping between sub-estimator's methods and meta-estimator's ones, i.e. +# ``fit`` used in ``fit`` and so on.
In order to understand how aliases work in +# meta-estimators, imagine our meta-estimator inside another one: + +meta_est = MetaClassifier(estimator=est).fit(X, y, aliased_sample_weight=my_weights) + +# %% +# In the above example, this is how each ``fit`` method will call the +# sub-estimator's ``fit``:: +# +# meta_est.fit(X, y, aliased_sample_weight=my_weights): +# ... # this estimator (est), expects aliased_sample_weight as seen above +# self.estimator_.fit(X, y, aliased_sample_weight=aliased_sample_weight): +# ... # now est passes aliased_sample_weight's value as sample_weight, +# # which is expected by the sub-estimator +# self.estimator_.fit(X, y, sample_weight=aliased_sample_weight) +# ... + +# %% +# Router and Consumer +# ------------------- +# To show how a slightly more complicated case would work, consider a case +# where a meta-estimator uses some metadata, but it also routes them to an +# underlying estimator. In this case, this meta-estimator is a consumer and a +# router at the same time. This is how we can implement one, and it is very +# similar to what we had before, with a few tweaks. 
+ + +class RouterConsumerClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + # NOTE(review): ``sample_weight`` has no default although ``None`` is + # handled below - confirm that requiring it is intended. + def fit(self, X, y, sample_weight, **fit_params): + if self.estimator is None: + raise ValueError("estimator cannot be None!") + + check_metadata(self, sample_weight=sample_weight) + + # put ``sample_weight`` back with the other metadata so that it is + # validated and routed together with them + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + + # meta-estimators are responsible for validating the given metadata + request_router = get_routing_for_object(self) + request_router.validate_metadata(params=fit_params, method="fit") + # we can use provided utility methods to map the given metadata to what + # is required by the underlying estimator + params = request_router.route_params(params=fit_params, caller="fit") + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + self.classes_ = self.estimator_.classes_ + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + # same as in ``fit``, we validate the given metadata + request_router = get_routing_for_object(self) + request_router.validate_metadata(params=predict_params, method="predict") + # and then prepare the input to the underlying ``predict`` method. + params = request_router.route_params(params=predict_params, caller="predict") + return self.estimator_.predict(X, **params.estimator.predict) + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self(self) + .add(estimator=self.estimator, method_mapping="one-to-one") + ) + return router + + +# %% +# The key part where the above estimator differs from our previous +# meta-estimator is accepting ``sample_weight`` explicitly in ``fit`` and +# including it in ``fit_params``. Making ``sample_weight`` an explicit argument +# makes sure ``set_fit_request(sample_weight=...)`` is present for this class. +# +# In ``get_metadata_routing``, we add ``self`` to the routing using +# ``add_self``.
Now let's look at some examples: + +# %% +# - No metadata requested +est = RouterConsumerClassifier(estimator=ExampleClassifier()) +print_routing(est) + + +# %% +# - ``sample_weight`` requested by child estimator +est = RouterConsumerClassifier( + estimator=ExampleClassifier().set_fit_request(sample_weight=True) +) +print_routing(est) + +# %% +# - ``sample_weight`` requested by meta-estimator +est = RouterConsumerClassifier(estimator=ExampleClassifier()).set_fit_request( + sample_weight=True +) +print_routing(est) + +# %% +# Note the difference in the requested metadata representations above. +# +# - We can also alias the metadata to pass different values to them: + +est = RouterConsumerClassifier( + estimator=ExampleClassifier().set_fit_request(sample_weight="clf_sample_weight"), +).set_fit_request(sample_weight="meta_clf_sample_weight") +print_routing(est) + +# %% +# However, ``fit`` of the meta-estimator only needs the alias for the +# sub-estimator, since it doesn't validate and route its own required metadata: +est.fit(X, y, sample_weight=my_weights, clf_sample_weight=my_other_weights) + +# %% +# - Alias only on the sub-estimator. This is useful if we don't want the +# meta-estimator to use the metadata, and we only want the metadata to be used +# by the sub-estimator. +est = RouterConsumerClassifier( + estimator=ExampleClassifier().set_fit_request(sample_weight="aliased_sample_weight") +).set_fit_request(sample_weight=True) +print_routing(est) + + +# %% +# Simple Pipeline +# --------------- +# A slightly more complicated use-case is a meta-estimator which does something +# similar to the ``Pipeline``. Here is a meta-estimator, which accepts a +# transformer and a classifier, and applies the transformer before running the +# classifier.
+ + +class SimplePipeline(ClassifierMixin, BaseEstimator): + """Apply ``transformer`` to ``X``, then fit or predict with ``classifier``. + + Metadata given to ``fit`` and ``predict`` is routed to the two + sub-objects as declared in ``get_metadata_routing``. + """ + + # list this class' own constructor parameters (it has no ``estimator``) + _required_parameters = ["transformer", "classifier"] + + def __init__(self, transformer, classifier): + self.transformer = transformer + self.classifier = classifier + + def fit(self, X, y, **fit_params): + # validate the given metadata and map it to each sub-object's methods + params = process_routing(self, "fit", fit_params) + + self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit) + X_transformed = self.transformer_.transform(X, **params.transformer.transform) + + self.classifier_ = clone(self.classifier).fit( + X_transformed, y, **params.classifier.fit + ) + return self + + def predict(self, X, **predict_params): + params = process_routing(self, "predict", predict_params) + + X_transformed = self.transformer_.transform(X, **params.transformer.transform) + return self.classifier_.predict(X_transformed, **params.classifier.predict) + + def get_metadata_routing(self): + # the transformer's ``fit`` and ``transform`` are both called from + # ``fit``, and its ``transform`` is also called from ``predict`` + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add( + transformer=self.transformer, + method_mapping=MethodMapping() + .add(callee="fit", caller="fit") + .add(callee="transform", caller="fit") + .add(callee="transform", caller="predict"), + ) + .add(classifier=self.classifier, method_mapping="one-to-one") + ) + return router + + +# %% +# Note the usage of :class:`~utils.metadata_routing.MethodMapping` to declare +# which methods of the child estimator (callee) are used in which methods of +# the meta estimator (caller). As you can see, we use the transformer's +# ``transform`` and ``fit`` methods in ``fit``, and its ``transform`` method in +# ``predict``, and that's what you see implemented in the routing structure of +# the pipeline class. +# +# Another difference in the above example with the previous ones is the usage +# of :func:`~utils.metadata_routing.process_routing`, which processes the input +# parameters, does the required validation, and returns the ``params`` which we +# had created in previous examples.
This reduces the boilerplate code a +# developer needs to write in each meta-estimator's method. Developers are +# strongly recommended to use this function unless there is a good reason +# against it. +# +# In order to test the above pipeline, let's add an example transformer. + + +class ExampleTransformer(TransformerMixin, BaseEstimator): + def fit(self, X, y, sample_weight=None): + check_metadata(self, sample_weight=sample_weight) + return self + + def transform(self, X, groups=None): + check_metadata(self, groups=groups) + return X + + +# %% +# Now we can test our pipeline, and see if metadata is correctly passed around. +# This example uses our simple pipeline, and our transformer, and our +# consumer+router estimator which uses our simple classifier. + +est = SimplePipeline( + transformer=ExampleTransformer() + # we want transformer's fit to receive sample_weight + .set_fit_request(sample_weight=True) + # we want transformer's transform to receive groups + .set_transform_request(groups=True), + classifier=RouterConsumerClassifier( + estimator=ExampleClassifier() + # we want this sub-estimator to receive sample_weight in fit + .set_fit_request(sample_weight=True) + # but not groups in predict + .set_predict_request(groups=False), + ).set_fit_request( + # and we want the meta-estimator to receive sample_weight as well + sample_weight=True + ), +) +est.fit(X, y, sample_weight=my_weights, groups=my_groups).predict( + X[:3], groups=my_groups +) + +# %% +# Deprecation / Default Value Change +# ---------------------------------- +# In this section we show how one should handle the case where a router also +# becomes a consumer, especially when it consumes the same metadata as its +# sub-estimator, or a consumer starts consuming a metadata which it wasn't in +# an older release. In this case, a warning should be raised for a while, to +# let users know the behavior is changed from previous versions.
+ + +class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + params = process_routing(self, "fit", fit_params) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + + def get_metadata_routing(self): + router = MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.estimator, method_mapping="one-to-one" + ) + return router + + +# %% +# As explained above, this is now a valid usage: + +reg = MetaRegressor(estimator=LinearRegression().set_fit_request(sample_weight=True)) +reg.fit(X, y, sample_weight=my_weights) + + +# %% +# Now imagine we further develop ``MetaRegressor`` and it now also *consumes* +# ``sample_weight``: + + +class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **fit_params): + params = process_routing(self, "fit", fit_params, sample_weight=sample_weight) + check_metadata(self, sample_weight=sample_weight) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self(self) + .add(estimator=self.estimator, method_mapping="one-to-one") + ) + return router + + +# %% +# The above implementation is almost no different than ``MetaRegressor``, and +# because of the default request value defined in ``__metadata_request__fit`` +# there is a warning raised. 
+ +with warnings.catch_warnings(record=True) as record: + WeightedMetaRegressor( + estimator=LinearRegression().set_fit_request(sample_weight=False) + ).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) + + +# %% +# When an estimator supports a metadata which wasn't supported before, the +# following pattern can be used to warn the users about it. + + +class ExampleRegressor(RegressorMixin, BaseEstimator): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + def fit(self, X, y, sample_weight=None): + check_metadata(self, sample_weight=sample_weight) + return self + + def predict(self, X): + return np.zeros(shape=(len(X))) + + +with warnings.catch_warnings(record=True) as record: + MetaRegressor(estimator=ExampleRegressor()).fit(X, y, sample_weight=my_weights) +for w in record: + print(w.message) + +# %% +# Third Party Development and scikit-learn Dependency +# --------------------------------------------------- +# +# As seen above, information is communicated between classes using a dictionary +# representation as the output of ``get_metadata_routing``. Therefore it is +# possible for a third party library to not have a hard dependency on +# scikit-learn, and internally (re)implement or vendor the functionality +# required to present the dictionary representation of the routing data to +# other classes inside and outside scikit-learn.
diff --git a/sklearn/base.py b/sklearn/base.py index cd88eb6d59f99..3b8000b7c32f5 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -25,6 +25,7 @@ from .utils.validation import _check_feature_names_in from .utils.validation import _generate_get_feature_names_out from .utils.validation import check_is_fitted +from .utils._metadata_requests import _MetadataRequester from .utils.validation import _get_feature_names from .utils._estimator_html_repr import estimator_html_repr from .utils._param_validation import validate_parameter_constraints @@ -86,7 +87,13 @@ def clone(estimator, *, safe=True): new_object_params = estimator.get_params(deep=False) for name, param in new_object_params.items(): new_object_params[name] = clone(param, safe=False) + new_object = klass(**new_object_params) + try: + new_object._metadata_request = copy.deepcopy(estimator._metadata_request) + except AttributeError: + pass + params_set = new_object.get_params(deep=False) # quick sanity check of the parameters of the clone @@ -151,7 +158,7 @@ def _pprint(params, offset=0, printer=repr): return lines -class BaseEstimator: +class BaseEstimator(_MetadataRequester): """Base class for all estimators in scikit-learn. 
Notes diff --git a/sklearn/inspection/_permutation_importance.py b/sklearn/inspection/_permutation_importance.py index 204dcd9117c77..f97a1f43b9ba9 100644 --- a/sklearn/inspection/_permutation_importance.py +++ b/sklearn/inspection/_permutation_importance.py @@ -15,7 +15,7 @@ def _weights_scorer(scorer, estimator, X, y, sample_weight): if sample_weight is not None: - return scorer(estimator, X, y, sample_weight) + return scorer(estimator, X, y, sample_weight=sample_weight) return scorer(estimator, X, y) diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index e93208f1c67e7..7ac6a7548cc78 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -18,13 +18,13 @@ # Arnaud Joly # License: Simplified BSD +import copy +import warnings from collections.abc import Iterable from functools import partial from collections import Counter import numpy as np -import copy -import warnings from . import ( r2_score, @@ -64,6 +64,11 @@ from ..utils.multiclass import type_of_target from ..base import is_regressor +from ..utils.metadata_routing import _MetadataRequester +from ..utils.metadata_routing import MetadataRequest +from ..utils.metadata_routing import MetadataRouter +from ..utils.metadata_routing import process_routing +from ..utils.metadata_routing import get_routing_for_object def _cached_call(cache, estimator, method, *args, **kwargs): @@ -102,11 +107,15 @@ def __call__(self, estimator, *args, **kwargs): cache = {} if self._use_cache(estimator) else None cached_call = partial(_cached_call, cache) + params = process_routing(self, "score", kwargs) + for name, scorer in self._scorers.items(): if isinstance(scorer, _BaseScorer): - score = scorer._score(cached_call, estimator, *args, **kwargs) + score = scorer._score( + cached_call, estimator, *args, **params.get(name).score + ) else: - score = scorer(estimator, *args, **kwargs) + score = scorer(estimator, *args, **params.get(name).score) scores[name] = score return scores @@ -141,8 +150,24 
@@ def _use_cache(self, estimator): return True return False + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. -class _BaseScorer: + Returns + ------- + routing : MetadataRouter + A :class:`~utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + return MetadataRouter(owner=self.__class__.__name__).add( + **self._scorers, method_mapping="score" + ) + + +class _BaseScorer(_MetadataRequester): def __init__(self, score_func, sign, kwargs): self._kwargs = kwargs self._score_func = score_func @@ -194,7 +219,7 @@ def __repr__(self): kwargs_string, ) - def __call__(self, estimator, X, y_true, sample_weight=None): + def __call__(self, estimator, X, y_true, **kwargs): """Evaluate predicted target values for X relative to y_true. Parameters @@ -209,29 +234,68 @@ def __call__(self, estimator, X, y_true, sample_weight=None): y_true : array-like Gold standard target values for X. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer, e.g. sample_weight. + Refer to :func:`set_score_request` for more details. + + .. versionadded:: 1.2 Returns ------- score : float Score function applied to prediction of estimator on X. """ - return self._score( - partial(_cached_call, None), - estimator, - X, - y_true, - sample_weight=sample_weight, - ) + return self._score(partial(_cached_call, None), estimator, X, y_true, **kwargs) def _factory_args(self): """Return non-default make_scorer arguments for repr.""" return "" + def _warn_overlap(self, message, kwargs): + """Warn if there is any overlap between ``self._kwargs`` and kwargs. + + This method is intended to be used to check for overlap between + ``self._kwargs`` and ``kwargs`` passed as metadata. 
+ """ + _kwargs = set() if self._kwargs is None else set(self._kwargs.keys()) + overlap = _kwargs.intersection(kwargs.keys()) + if overlap: + warnings.warn( + f"{message} Overlapping parameters are: {overlap}", UserWarning + ) + + def set_score_request(self, **kwargs): + """Set requested parameters by the scorer. + + Please see :ref:`User Guide ` on how the routing + mechanism works. + + .. versionadded:: 1.2 + + Parameters + ---------- + kwargs : dict + Arguments should be of the form ``param_name=alias``, and `alias` + can be either one of ``{True, False, None, str}`` or an instance of + RequestType. + """ + self._warn_overlap( + message=( + "You are setting metadata request for parameters which are " + "already set as kwargs for this metric. These set values will be " + "overridden by passed metadata if provided. Please pass them either " + "as metadata or kwargs to `make_scorer`." + ), + kwargs=kwargs, + ) + self._metadata_request = MetadataRequest(owner=self.__class__.__name__) + for param, alias in kwargs.items(): + self._metadata_request.score.add_request(param=param, alias=alias) + return self + class _PredictScorer(_BaseScorer): - def _score(self, method_caller, estimator, X, y_true, sample_weight=None): + def _score(self, method_caller, estimator, X, y_true, **kwargs): """Evaluate predicted target values for X relative to y_true. Parameters @@ -250,26 +314,32 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): y_true : array-like Gold standard target values for X. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer, e.g. sample_weight. + Refer to :func:`set_score_request` for more details. + + .. versionadded:: 1.2 Returns ------- score : float Score function applied to prediction of estimator on X. """ - + self._warn_overlap( + message=( + "There is an overlap between set kwargs of this scorer instance and" + " passed metadata. 
Please pass them either as kwargs to `make_scorer`" + " or metadata, but not both." + ), + kwargs=kwargs, + ) y_pred = method_caller(estimator, "predict", X) - if sample_weight is not None: - return self._sign * self._score_func( - y_true, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y_true, y_pred, **self._kwargs) + scoring_kwargs = {**self._kwargs, **kwargs} + return self._sign * self._score_func(y_true, y_pred, **scoring_kwargs) class _ProbaScorer(_BaseScorer): - def _score(self, method_caller, clf, X, y, sample_weight=None): + def _score(self, method_caller, clf, X, y, **kwargs): """Evaluate predicted probabilities for X relative to y_true. Parameters @@ -289,14 +359,25 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): Gold standard target values for X. These must be class labels, not probabilities. - sample_weight : array-like, default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer, e.g. sample_weight. + Refer to :func:`set_score_request` for more details. + + .. versionadded:: 1.2 Returns ------- score : float Score function applied to prediction of estimator on X. """ + self._warn_overlap( + message=( + "There is an overlap between set kwargs of this scorer instance and" + " passed metadata. Please pass them either as kwargs to `make_scorer`" + " or metadata, but not both." + ), + kwargs=kwargs, + ) y_type = type_of_target(y) y_pred = method_caller(clf, "predict_proba", X) @@ -305,19 +386,22 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): # problem: (when only 2 class are given to `y_true` during scoring) # Thus, we need to check for the shape of `y_pred`. 
y_pred = self._select_proba_binary(y_pred, clf.classes_) - if sample_weight is not None: - return self._sign * self._score_func( - y, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y, y_pred, **self._kwargs) + + scoring_kwargs = {**self._kwargs, **kwargs} + # this is for backward compatibility to avoid passing sample_weight + # to the scorer if it's None + # TODO(1.3) Probably remove + if scoring_kwargs.get("sample_weight", -1) is None: + del scoring_kwargs["sample_weight"] + + return self._sign * self._score_func(y, y_pred, **scoring_kwargs) def _factory_args(self): return ", needs_proba=True" class _ThresholdScorer(_BaseScorer): - def _score(self, method_caller, clf, X, y, sample_weight=None): + def _score(self, method_caller, clf, X, y, **kwargs): """Evaluate decision function output for X relative to y_true. Parameters @@ -339,14 +423,25 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): Gold standard target values for X. These must be class labels, not decision function values. - sample_weight : array-like, default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer, e.g. sample_weight. + Refer to :func:`set_score_request` for more details. + + .. versionadded:: 1.2 Returns ------- score : float Score function applied to prediction of estimator on X. """ + self._warn_overlap( + message=( + "There is an overlap between set kwargs of this scorer instance and" + " passed metadata. Please pass them either as kwargs to `make_scorer`" + " or metadata, but not both." 
+ ), + kwargs=kwargs, + ) y_type = type_of_target(y) if y_type not in ("binary", "multilabel-indicator"): @@ -377,12 +472,13 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): elif isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T - if sample_weight is not None: - return self._sign * self._score_func( - y, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y, y_pred, **self._kwargs) + scoring_kwargs = {**self._kwargs, **kwargs} + # this is for backward compatibility to avoid passing sample_weight + # to the scorer if it's None + # TODO(1.3) Probably remove + if scoring_kwargs.get("sample_weight", -1) is None: + del scoring_kwargs["sample_weight"] + return self._sign * self._score_func(y, y_pred, **scoring_kwargs) def _factory_args(self): return ", needs_threshold=True" @@ -425,9 +521,31 @@ def get_scorer(scoring): return scorer -def _passthrough_scorer(estimator, *args, **kwargs): - """Function that wraps estimator.score""" - return estimator.score(*args, **kwargs) +class _PassthroughScorer: + def __init__(self, estimator): + self._estimator = estimator + + def __call__(self, estimator, *args, **kwargs): + """Method that wraps estimator.score""" + return estimator.score(*args, **kwargs) + + def get_metadata_routing(self): + """Get requested data properties. + + .. versionadded:: 1.2 + + Returns + ------- + routing : MetadataRouter + A :class:`~utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + # This scorer doesn't do any validation or routing, it only exposes the + # score requests to the parent object. This object behaves as a + # consumer rather than a router. 
+ res = MetadataRequest(owner=self._estimator.__class__.__name__) + res.score = get_routing_for_object(self._estimator).score + return res def check_scoring(estimator, scoring=None, *, allow_none=False): @@ -482,7 +600,7 @@ def check_scoring(estimator, scoring=None, *, allow_none=False): return get_scorer(scoring) elif scoring is None: if hasattr(estimator, "score"): - return _passthrough_scorer + return _PassthroughScorer(estimator) elif allow_none: return None else: @@ -615,7 +733,7 @@ def make_scorer( ---------- score_func : callable Score function (or loss function) with signature - `score_func(y, y_pred, **kwargs)`. + ``score_func(y, y_pred, **kwargs)``. greater_is_better : bool, default=True Whether `score_func` is a score function (default), meaning high is diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 204a895742db7..d6c4eee13e910 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -15,6 +15,8 @@ from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_equal from sklearn.utils._testing import ignore_warnings +from sklearn.utils.metadata_routing import MetadataRouter, RequestType +from sklearn.tests.test_metadata_routing import assert_request_is_empty from sklearn.base import BaseEstimator from sklearn.metrics import ( @@ -37,7 +39,7 @@ from sklearn.metrics import check_scoring from sklearn.metrics._scorer import ( _PredictScorer, - _passthrough_scorer, + _PassthroughScorer, _MultimetricScorer, _check_multimetric_scoring, ) @@ -238,7 +240,7 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator): estimator = EstimatorWithFitAndScore() estimator.fit([[1]], [1]) scorer = scoring_validator(estimator) - assert scorer is _passthrough_scorer + assert isinstance(scorer, _PassthroughScorer) assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) estimator = 
EstimatorWithFitAndPredict() @@ -632,11 +634,14 @@ def test_classification_scorer_sample_weight(): else: target = y_test try: + scorer = scorer.set_score_request(sample_weight=True) weighted = scorer( estimator[name], X_test, target, sample_weight=sample_weight ) ignored = scorer(estimator[name], X_test[10:], target[10:]) unweighted = scorer(estimator[name], X_test, target) + # this should not raise. sample_weight should be ignored if None. + _ = scorer(estimator[name], X_test[:10], target[:10], sample_weight=None) assert weighted != unweighted, ( f"scorer {name} behaves identically when called with " f"sample weights: {weighted} vs {unweighted}" @@ -1155,3 +1160,71 @@ def test_scorer_no_op_multiclass_select_proba(): labels=lr.classes_, ) scorer(lr, X_test, y_test) + + +@pytest.mark.parametrize("name", get_scorer_names(), ids=get_scorer_names()) +def test_scorer_metadata_request(name): + scorer = get_scorer(name) + assert hasattr(scorer, "set_score_request") + assert hasattr(scorer, "get_metadata_routing") + + assert_request_is_empty(scorer.get_metadata_routing()) + + weighted_scorer = scorer.set_score_request(sample_weight=True) + # set_score_request should mutate the instance + assert weighted_scorer is scorer + + assert_request_is_empty(weighted_scorer.get_metadata_routing(), exclude="score") + assert ( + weighted_scorer.get_metadata_routing().score.requests["sample_weight"] + == RequestType.REQUESTED + ) + + # make sure putting the scorer in a router doesn't request anything + router = MetadataRouter(owner="test").add( + method_mapping="score", scorer=get_scorer(name) + ) + with pytest.raises(TypeError, match="got unexpected argument"): + router.validate_metadata(params={"sample_weight": 1}, method="score") + routed_params = router.route_params(params={"sample_weight": 1}, caller="score") + assert not routed_params.scorer.score + + # make sure putting weighted_scorer in a router requests sample_weight + router = MetadataRouter(owner="test").add( + 
scorer=weighted_scorer, method_mapping="score" + ) + router.validate_metadata(params={"sample_weight": 1}, method="score") + routed_params = router.route_params(params={"sample_weight": 1}, caller="score") + assert list(routed_params.scorer.score.keys()) == ["sample_weight"] + + +def test_metadata_kwarg_conflict(): + X, y = make_classification( + n_classes=3, n_informative=3, n_samples=20, random_state=0 + ) + lr = LogisticRegression().fit(X, y) + + scorer = make_scorer( + roc_auc_score, + needs_proba=True, + multi_class="ovo", + labels=lr.classes_, + ) + with pytest.warns(UserWarning, match="already set as kwargs"): + scorer.set_score_request(labels=True) + + with pytest.warns(UserWarning, match="There is an overlap"): + scorer(lr, X, y, labels=lr.classes_) + + +def test_PassthroughScorer_metadata_request(): + scorer = _PassthroughScorer( + estimator=LinearSVC() + .set_score_request(sample_weight="alias") + .set_fit_request(sample_weight=True) + ) + # test that _PassthroughScorer leaves everything other than `score` empty + assert_request_is_empty(scorer.get_metadata_routing(), exclude="score") + # test that _PassthroughScorer doesn't behave like a router and leaves + # the request as is. + assert scorer.get_metadata_routing().score.requests["sample_weight"] == "alias" diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index d2a0b5e1fc329..bfb04abf66955 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -28,6 +28,8 @@ from ..utils.validation import _num_samples, column_or_1d from ..utils.validation import check_array from ..utils.multiclass import type_of_target +from ..utils.metadata_routing import RequestType +from ..utils.metadata_routing import _MetadataRequester from ..base import _pprint __all__ = [ @@ -51,12 +53,30 @@ ] -class BaseCrossValidator(metaclass=ABCMeta): +class GroupsComsumerMixin(_MetadataRequester): + """A Mixin to ``groups`` by default. 
+ + This Mixin makes the object to request ``groups`` by default as + ``REQUESTED``. + + .. versionadded:: 1.2 + """ + + __metadata_request__split = {"groups": RequestType.REQUESTED} + + +class BaseCrossValidator(_MetadataRequester, metaclass=ABCMeta): """Base class for all cross-validators Implementations must define `_iter_test_masks` or `_iter_test_indices`. """ + # This indicates that by default CV splitters don't have a "groups" kwarg, + # unless indicated by inheriting from ``GroupsComsumerMixin``. + # This also prevents ``set_split_request`` to be generated for splitters + # which don't support ``groups``. + __metadata_request__split = {"groups": RequestType.UNUSED} + def split(self, X, y=None, groups=None): """Generate indices to split data into training and test set. @@ -450,7 +470,7 @@ def _iter_test_indices(self, X, y=None, groups=None): current = stop -class GroupKFold(_BaseKFold): +class GroupKFold(GroupsComsumerMixin, _BaseKFold): """K-fold iterator variant with non-overlapping groups. The same group will not appear in two different folds (the number of @@ -752,7 +772,7 @@ def split(self, X, y, groups=None): return super().split(X, y, groups) -class StratifiedGroupKFold(_BaseKFold): +class StratifiedGroupKFold(GroupsComsumerMixin, _BaseKFold): """Stratified K-Folds iterator variant with non-overlapping groups. 
This cross-validation object is a variation of StratifiedKFold attempts to @@ -1102,7 +1122,7 @@ def split(self, X, y=None, groups=None): ) -class LeaveOneGroupOut(BaseCrossValidator): +class LeaveOneGroupOut(GroupsComsumerMixin, BaseCrossValidator): """Leave One Group Out cross-validator Provides train/test indices to split data according to a third-party @@ -1219,7 +1239,7 @@ def split(self, X, y=None, groups=None): return super().split(X, y, groups) -class LeavePGroupsOut(BaseCrossValidator): +class LeavePGroupsOut(GroupsComsumerMixin, BaseCrossValidator): """Leave P Group(s) Out cross-validator Provides train/test indices to split data according to a third-party @@ -1353,7 +1373,7 @@ def split(self, X, y=None, groups=None): return super().split(X, y, groups) -class _RepeatedSplits(metaclass=ABCMeta): +class _RepeatedSplits(_MetadataRequester, metaclass=ABCMeta): """Repeated splits for an arbitrary randomized CV splitter. Repeats splits for cross-validators n times with different randomization @@ -1377,6 +1397,12 @@ class _RepeatedSplits(metaclass=ABCMeta): and shuffle. """ + # This indicates that by default CV splitters don't have a "groups" kwarg, + # unless indicated by inheriting from ``GroupsComsumerMixin``. + # This also prevents ``set_split_request`` to be generated for splitters + # which don't support ``groups``. + __metadata_request__split = {"groups": RequestType.UNUSED} + def __init__(self, cv, *, n_repeats=10, random_state=None, **cvargs): if not isinstance(n_repeats, numbers.Integral): raise ValueError("Number of repetitions must be of Integral type.") @@ -1567,9 +1593,15 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None): ) -class BaseShuffleSplit(metaclass=ABCMeta): +class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta): """Base class for ShuffleSplit and StratifiedShuffleSplit""" + # This indicates that by default CV splitters don't have a "groups" kwarg, + # unless indicated by inheriting from ``GroupsComsumerMixin``. 
+ # This also prevents ``set_split_request`` to be generated for splitters + # which don't support ``groups``. + __metadata_request__split = {"groups": RequestType.UNUSED} + def __init__( self, n_splits=10, *, test_size=None, train_size=None, random_state=None ): @@ -1734,7 +1766,7 @@ def _iter_indices(self, X, y=None, groups=None): yield ind_train, ind_test -class GroupShuffleSplit(ShuffleSplit): +class GroupShuffleSplit(GroupsComsumerMixin, ShuffleSplit): """Shuffle-Group(s)-Out cross-validation iterator Provides randomized train/test indices to split data according to a diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index e1466d69d3902..84b951f7d313c 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -1900,7 +1900,13 @@ def _run_search(self, evaluate): attr[0].islower() and attr[-1:] == "_" and attr - not in {"cv_results_", "best_estimator_", "refit_time_", "classes_"} + not in { + "cv_results_", + "best_estimator_", + "refit_time_", + "classes_", + "scorer_", + } ): assert getattr(gscv, attr) == getattr(mycv, attr), ( "Attribute %s not equal" % attr diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index f502ebc8a3b6a..fb0b467e01bca 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -48,6 +48,32 @@ from sklearn.svm import SVC +from sklearn.utils.metadata_routing import RequestType +from sklearn.tests.test_metadata_routing import assert_request_is_empty + +NO_GROUP_SPLITTERS = [ + KFold(), + StratifiedKFold(), + TimeSeriesSplit(), + LeaveOneOut(), + LeavePOut(p=2), + ShuffleSplit(), + StratifiedShuffleSplit(test_size=0.5), + PredefinedSplit([1, 1, 2, 2]), + RepeatedKFold(), + RepeatedStratifiedKFold(), +] + +GROUP_SPLITTERS = [ + GroupKFold(), + LeavePGroupsOut(n_groups=1), + StratifiedGroupKFold(), + LeaveOneGroupOut(), + 
GroupShuffleSplit(), +] + +ALL_SPLITTERS = NO_GROUP_SPLITTERS + GROUP_SPLITTERS # type: ignore + X = np.ones(10) y = np.arange(10) // 2 P_sparse = coo_matrix(np.eye(5)) @@ -1792,7 +1818,11 @@ def test_nested_cv(): cvs = [ LeaveOneGroupOut(), StratifiedKFold(n_splits=2), + LeaveOneOut(), GroupKFold(n_splits=3), + StratifiedKFold(), + StratifiedGroupKFold(), + StratifiedShuffleSplit(n_splits=3, random_state=0), ] for inner_cv, outer_cv in combinations_with_replacement(cvs, 2): @@ -1921,3 +1951,25 @@ def test_random_state_shuffle_false(Klass): ) def test_yields_constant_splits(cv, expected): assert _yields_constant_splits(cv) == expected + + +@pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS]) +def test_splitter_get_metadata_routing(cv): + """Check get_metadata_routing returns the correct MetadataRouter.""" + assert hasattr(cv, "get_metadata_routing") + metadata = cv.get_metadata_routing() + if cv in GROUP_SPLITTERS: + assert metadata.split.requests["groups"] == RequestType.REQUESTED + elif cv in NO_GROUP_SPLITTERS: + assert not metadata.split.requests + + assert_request_is_empty(metadata, exclude=["split"]) + + +@pytest.mark.parametrize("cv", ALL_SPLITTERS, ids=[str(cv) for cv in ALL_SPLITTERS]) +def test_splitter_set_split_request(cv): + """Check set_split_request is defined for group splitters and not for others.""" + if cv in GROUP_SPLITTERS: + assert hasattr(cv, "set_split_request") + elif cv in NO_GROUP_SPLITTERS: + assert not hasattr(cv, "set_split_request") diff --git a/sklearn/tests/test_metadata_routing.py b/sklearn/tests/test_metadata_routing.py new file mode 100644 index 0000000000000..32d2c4e21551a --- /dev/null +++ b/sklearn/tests/test_metadata_routing.py @@ -0,0 +1,846 @@ +""" +Metadata Routing Utility Tests +""" + +# Author: Adrin Jalali +# License: BSD 3 clause + +import re +import numpy as np +import pytest + +from sklearn.base import BaseEstimator +from sklearn.base import ClassifierMixin +from sklearn.base import 
RegressorMixin +from sklearn.base import TransformerMixin +from sklearn.base import MetaEstimatorMixin +from sklearn.base import clone +from sklearn.linear_model import LinearRegression +from sklearn.utils.validation import check_is_fitted +from sklearn.utils.metadata_routing import RequestType +from sklearn.utils.metadata_routing import MetadataRequest +from sklearn.utils.metadata_routing import get_routing_for_object +from sklearn.utils.metadata_routing import MetadataRouter +from sklearn.utils.metadata_routing import MethodMapping +from sklearn.utils.metadata_routing import process_routing +from sklearn.utils._metadata_requests import MethodMetadataRequest +from sklearn.utils._metadata_requests import _MetadataRequester +from sklearn.utils._metadata_requests import METHODS + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + + +def assert_request_is_empty(metadata_request, exclude=None): + """Check if a metadata request dict is empty. + + One can exclude a method or a list of methods from the check using the + ``exclude`` parameter.
+ """ + if isinstance(metadata_request, MetadataRouter): + for _, route_mapping in metadata_request: + assert_request_is_empty(route_mapping.router) + return + + exclude = [] if exclude is None else exclude + for method in METHODS: + if method in exclude: + continue + mmr = getattr(metadata_request, method) + props = [ + prop + for prop, alias in mmr.requests.items() + if isinstance(alias, str) + or RequestType(alias) != RequestType.ERROR_IF_PASSED + ] + assert not len(props) + + +def assert_request_equal(request, dictionary): + for method, requests in dictionary.items(): + mmr = getattr(request, method) + assert mmr.requests == requests + + empty_methods = [method for method in METHODS if method not in dictionary] + for method in empty_methods: + assert not len(getattr(request, method).requests) + + +def record_metadata(obj, method, **kwargs): + """Utility function to store passed metadata to a method.""" + if not hasattr(obj, "_records"): + setattr(obj, "_records", dict()) + obj._records[method] = kwargs + + +def check_recorded_metadata(obj, method, **kwargs): + """Check whether the expected metadata is passed to the object's method.""" + records = getattr(obj, "_records", dict()).get(method, dict()) + assert set(kwargs.keys()) == set(records.keys()) + for key, value in kwargs.items(): + assert records[key] is value + + +class MetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + """A meta-regressor which is only a router.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, **fit_params): + params = process_routing(self, "fit", fit_params) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + + def get_metadata_routing(self): + router = MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.estimator, method_mapping="one-to-one" + ) + return router + + +class RegressorMetadata(RegressorMixin, BaseEstimator): + """A regressor consuming a metadata.""" + + def fit(self, X, y, 
sample_weight=None): + record_metadata(self, "fit", sample_weight=sample_weight) + return self + + def predict(self, X): + return np.zeros(shape=(len(X))) + + +class WeightedMetaRegressor(MetaEstimatorMixin, RegressorMixin, BaseEstimator): + """A meta-regressor which is also a consumer.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **fit_params): + record_metadata(self, "fit", sample_weight=sample_weight) + params = process_routing(self, "fit", fit_params, sample_weight=sample_weight) + self.estimator_ = clone(self.estimator).fit(X, y, **params.estimator.fit) + return self + + def predict(self, X, **predict_params): + params = process_routing(self, "predict", predict_params) + return self.estimator_.predict(X, **params.estimator.predict) + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self(self) + .add(estimator=self.estimator, method_mapping="one-to-one") + ) + return router + + +class ClassifierNoMetadata(ClassifierMixin, BaseEstimator): + """An estimator which accepts no metadata on any method.""" + + def fit(self, X, y): + return self + + def predict(self, X): + return np.ones(len(X)) + + +class ClassifierFitMetadata(ClassifierMixin, BaseEstimator): + """An estimator accepting two metadata in its ``fit`` method.""" + + def fit(self, X, y, sample_weight=None, brand=None): + record_metadata(self, "fit", sample_weight=sample_weight, brand=brand) + return self + + def predict(self, X): + return np.ones(len(X)) + + +class SimpleMetaClassifier(MetaEstimatorMixin, ClassifierMixin, BaseEstimator): + """A meta-estimator which also consumes sample_weight itself in ``fit``.""" + + def __init__(self, estimator): + self.estimator = estimator + + def fit(self, X, y, sample_weight=None, **kwargs): + record_metadata(self, "fit", sample_weight=sample_weight) + params = process_routing(self, "fit", kwargs, sample_weight=sample_weight) + self.estimator_ = 
clone(self.estimator).fit(X, y, **params.estimator.fit) + return self + + def get_metadata_routing(self): + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add_self(self) + .add(estimator=self.estimator, method_mapping="fit") + ) + return router + + +class TransformerMetadata(TransformerMixin, BaseEstimator): + """A transformer which accepts metadata on fit and transform.""" + + def fit(self, X, y=None, brand=None, sample_weight=None): + record_metadata(self, "fit", brand=brand, sample_weight=sample_weight) + return self + + def transform(self, X, sample_weight=None): + record_metadata(self, "transform", sample_weight=sample_weight) + return X + + +class MetaTransformer(MetaEstimatorMixin, TransformerMixin, BaseEstimator): + """A simple meta-transformer.""" + + def __init__(self, transformer): + self.transformer = transformer + + def fit(self, X, y=None, **fit_params): + params = process_routing(self, "fit", fit_params) + self.transformer_ = clone(self.transformer).fit(X, y, **params.transformer.fit) + return self + + def transform(self, X, y=None, **transform_params): + params = process_routing(self, "transform", transform_params) + return self.transformer_.transform(X, **params.transformer.transform) + + def get_metadata_routing(self): + return MetadataRouter(owner=self.__class__.__name__).add( + transformer=self.transformer, method_mapping="one-to-one" + ) + + +class SimplePipeline(BaseEstimator): + """A very simple pipeline, assuming the last step is always a predictor.""" + + def __init__(self, steps): + self.steps = steps + + def fit(self, X, y, **fit_params): + self.steps_ = [] + params = process_routing(self, "fit", fit_params) + X_transformed = X + for i, step in enumerate(self.steps[:-1]): + transformer = clone(step).fit( + X_transformed, y, **params.get(f"step_{i}").fit + ) + self.steps_.append(transformer) + X_transformed = transformer.transform( + X_transformed, **params.get(f"step_{i}").transform + ) + + self.steps_.append( + 
clone(self.steps[-1]).fit(X_transformed, y, **params.predictor.fit) + ) + return self + + def predict(self, X, **predict_params): + check_is_fitted(self) + X_transformed = X + params = process_routing(self, "predict", predict_params) + for i, step in enumerate(self.steps_[:-1]): + X_transformed = step.transform(X_transformed, **params.get(f"step_{i}").transform) + + return self.steps_[-1].predict(X_transformed, **params.predictor.predict) + + def get_metadata_routing(self): + router = MetadataRouter(owner=self.__class__.__name__) + for i, step in enumerate(self.steps[:-1]): + router.add( + **{f"step_{i}": step}, + method_mapping=MethodMapping() + .add(callee="fit", caller="fit") + .add(callee="transform", caller="fit") + .add(callee="transform", caller="predict"), + ) + router.add(predictor=self.steps[-1], method_mapping="one-to-one") + return router + + +def test_assert_request_is_empty(): + requests = MetadataRequest(owner="test") + assert_request_is_empty(requests) + + requests.fit.add_request(param="foo", alias=RequestType.ERROR_IF_PASSED) + # this should still work, since ERROR_IF_PASSED is the default value + assert_request_is_empty(requests) + + requests.fit.add_request(param="bar", alias="value") + with pytest.raises(AssertionError): + # now requests is no longer empty + assert_request_is_empty(requests) + + # but one can exclude a method + assert_request_is_empty(requests, exclude="fit") + + requests.score.add_request(param="carrot", alias=RequestType.REQUESTED) + with pytest.raises(AssertionError): + # excluding `fit` is not enough + assert_request_is_empty(requests, exclude="fit") + + # and excluding both fit and score would avoid an exception + assert_request_is_empty(requests, exclude=["fit", "score"]) + + # test if a router is empty + assert_request_is_empty( + MetadataRouter(owner="test") + .add_self(WeightedMetaRegressor(estimator=None)) + .add(method_mapping="fit", estimator=RegressorMetadata()) + ) + + +def test_default_requests(): + class 
OddEstimator(BaseEstimator): + __metadata_request__fit = { + # set a different default request + "sample_weight": RequestType.REQUESTED + } # type: ignore + + odd_request = get_routing_for_object(OddEstimator()) + assert odd_request.fit.requests == {"sample_weight": RequestType.REQUESTED} + + # check other test estimators + assert not len(get_routing_for_object(ClassifierNoMetadata()).fit.requests) + assert_request_is_empty(ClassifierNoMetadata().get_metadata_routing()) + + trs_request = get_routing_for_object(TransformerMetadata()) + assert trs_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + } + assert trs_request.transform.requests == { + "sample_weight": RequestType(None), + } + assert_request_is_empty(trs_request) + + est_request = get_routing_for_object(ClassifierFitMetadata()) + assert est_request.fit.requests == { + "sample_weight": RequestType(None), + "brand": RequestType(None), + } + assert_request_is_empty(est_request) + + +def test_simple_metadata_routing(): + # Tests that metadata is properly routed + + # The underlying estimator doesn't accept or request metadata + clf = SimpleMetaClassifier(estimator=ClassifierNoMetadata()) + clf.fit(X, y) + + # Meta-estimator consumes sample_weight, but doesn't forward it to the underlying + # estimator + clf = SimpleMetaClassifier(estimator=ClassifierNoMetadata()) + clf.fit(X, y, sample_weight=my_weights) + + # If the estimator accepts the metadata but doesn't explicitly say it doesn't + # need it, there's an error + clf = SimpleMetaClassifier(estimator=ClassifierFitMetadata()) + with pytest.raises( + ValueError, + match=( + "sample_weight is passed but is not explicitly set as requested or not for" + " ClassifierFitMetadata.fit" + ), + ): + clf.fit(X, y, sample_weight=my_weights) + + # Explicitly saying the estimator doesn't need it, makes the error go away, + # because in this case `SimpleMetaClassifier` consumes `sample_weight`. 
If + # there was no consumer of sample_weight, passing it would result in an + # error. + clf = SimpleMetaClassifier( + estimator=ClassifierFitMetadata().set_fit_request(sample_weight=False) + ) + # this doesn't raise since SimpleMetaClassifier itself is a consumer, + # and passing metadata to the consumer directly is fine regardless of its + # metadata_request values. + clf.fit(X, y, sample_weight=my_weights) + check_recorded_metadata(clf.estimator_, "fit", sample_weight=None, brand=None) + + # Requesting a metadata will make the meta-estimator forward it correctly + clf = SimpleMetaClassifier( + estimator=ClassifierFitMetadata().set_fit_request(sample_weight=True) + ) + clf.fit(X, y, sample_weight=my_weights) + check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights, brand=None) + + # And requesting it with an alias + clf = SimpleMetaClassifier( + estimator=ClassifierFitMetadata().set_fit_request( + sample_weight="alternative_weight" + ) + ) + clf.fit(X, y, alternative_weight=my_weights) + check_recorded_metadata(clf.estimator_, "fit", sample_weight=my_weights, brand=None) + + +def test_nested_routing(): + # check if metadata is routed in a nested routing situation. 
+ pipeline = SimplePipeline( + [ + MetaTransformer( + transformer=TransformerMetadata() + .set_fit_request(brand=True, sample_weight=False) + .set_transform_request(sample_weight=True) + ), + WeightedMetaRegressor( + estimator=RegressorMetadata().set_fit_request( + sample_weight="inner_weights" + ) + ).set_fit_request(sample_weight="outer_weights"), + ] + ) + w1, w2, w3 = [1], [2], [3] + pipeline.fit( + X, y, brand=my_groups, sample_weight=w1, outer_weights=w2, inner_weights=w3 + ) + check_recorded_metadata( + pipeline.steps_[0].transformer_, "fit", brand=my_groups, sample_weight=None + ) + check_recorded_metadata( + pipeline.steps_[0].transformer_, "transform", sample_weight=w1 + ) + check_recorded_metadata(pipeline.steps_[1], "fit", sample_weight=w2) + check_recorded_metadata(pipeline.steps_[1].estimator_, "fit", sample_weight=w3) + + pipeline.predict(X, sample_weight=w3) + check_recorded_metadata( + pipeline.steps_[0].transformer_, "transform", sample_weight=w3 + ) + + +def test_nested_routing_conflict(): + # check if an error is raised if there's a conflict between keys + pipeline = SimplePipeline( + [ + MetaTransformer( + transformer=TransformerMetadata() + .set_fit_request(brand=True, sample_weight=False) + .set_transform_request(sample_weight=True) + ), + WeightedMetaRegressor( + estimator=RegressorMetadata().set_fit_request(sample_weight=True) + ).set_fit_request(sample_weight="outer_weights"), + ] + ) + w1, w2 = [1], [2] + with pytest.raises( + ValueError, + match=( + re.escape( + "In WeightedMetaRegressor, there is a conflict on sample_weight between" + " what is requested for this estimator and what is requested by its" + " children. You can resolve this conflict by using an alias for the" + " child estimator(s) requested metadata." 
+ ) + ), + ): + pipeline.fit(X, y, brand=my_groups, sample_weight=w1, outer_weights=w2) + + +def test_invalid_metadata(): + # check that passing wrong metadata raises an error + trs = MetaTransformer( + transformer=TransformerMetadata().set_transform_request(sample_weight=True) + ) + with pytest.raises( + TypeError, + match=(re.escape("transform got unexpected argument(s) {'other_param'}")), + ): + trs.fit(X, y).transform(X, other_param=my_weights) + + # passing a metadata which is not requested by any estimator should also raise + trs = MetaTransformer( + transformer=TransformerMetadata().set_transform_request(sample_weight=False) + ) + with pytest.raises( + TypeError, + match=(re.escape("transform got unexpected argument(s) {'sample_weight'}")), + ): + trs.fit(X, y).transform(X, sample_weight=my_weights) + + +def test_get_metadata_routing(): + class TestDefaultsBadMethodName(_MetadataRequester): + __metadata_request__fit = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + # this will raise an error since we don't understand "other_method" as a method + __metadata_request__other_method = {"my_param": True} + + class TestDefaults(_MetadataRequester): + __metadata_request__fit = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__score = { + "sample_weight": RequestType.ERROR_IF_PASSED, + "my_param": True, + "my_other_param": RequestType.ERROR_IF_PASSED, + } + __metadata_request__predict = {"my_param": True} + + with pytest.raises( + AttributeError, match="'MetadataRequest' object has no attribute 'other_method'" + ): + TestDefaultsBadMethodName().get_metadata_routing() + + expected = { + "score": { + "my_param": RequestType.REQUESTED, + "my_other_param": RequestType.ERROR_IF_PASSED, + 
"sample_weight": RequestType.ERROR_IF_PASSED, + }, + "fit": { + "my_other_param": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.ERROR_IF_PASSED, + }, + "predict": {"my_param": RequestType.REQUESTED}, + } + assert_request_equal(TestDefaults().get_metadata_routing(), expected) + + est = TestDefaults().set_score_request(my_param="other_param") + expected = { + "score": { + "my_param": "other_param", + "my_other_param": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.ERROR_IF_PASSED, + }, + "fit": { + "my_other_param": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.ERROR_IF_PASSED, + }, + "predict": {"my_param": RequestType.REQUESTED}, + } + assert_request_equal(est.get_metadata_routing(), expected) + + est = TestDefaults().set_fit_request(sample_weight=True) + expected = { + "score": { + "my_param": RequestType.REQUESTED, + "my_other_param": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.ERROR_IF_PASSED, + }, + "fit": { + "my_other_param": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.REQUESTED, + }, + "predict": {"my_param": RequestType.REQUESTED}, + } + assert_request_equal(est.get_metadata_routing(), expected) + + +def test_setting_default_requests(): + # Test _get_default_requests method + test_cases = dict() + + class ExplicitRequest(BaseEstimator): + # `fit` doesn't accept `props` explicitly, but we want to request it + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} + + def fit(self, X, y, **kwargs): + return self + + test_cases[ExplicitRequest] = {"prop": RequestType.ERROR_IF_PASSED} + + class ExplicitRequestOverwrite(BaseEstimator): + # `fit` explicitly accepts `props`, but we want to change the default + # request value from ERROR_IF_PASSEd to REQUESTED + __metadata_request__fit = {"prop": RequestType.REQUESTED} + + def fit(self, X, y, prop=None, **kwargs): + return self + + test_cases[ExplicitRequestOverwrite] = {"prop": RequestType.REQUESTED} + + class 
ImplicitRequest(BaseEstimator): + # `fit` requests `prop` and the default ERROR_IF_PASSED should be used + def fit(self, X, y, prop=None, **kwargs): + return self + + test_cases[ImplicitRequest] = {"prop": RequestType.ERROR_IF_PASSED} + + class ImplicitRequestRemoval(BaseEstimator): + # `fit` (in this class or a parent) requests `prop`, but we don't want + # it requested at all. + __metadata_request__fit = {"prop": RequestType.UNUSED} + + def fit(self, X, y, prop=None, **kwargs): + return self + + test_cases[ImplicitRequestRemoval] = {} + + for Klass, requests in test_cases.items(): + assert get_routing_for_object(Klass()).fit.requests == requests + assert_request_is_empty(Klass().get_metadata_routing(), exclude="fit") + Klass().fit(None, None) # for coverage + + +def test_method_metadata_request(): + mmr = MethodMetadataRequest(owner="test", method="fit") + + with pytest.raises( + ValueError, match="alias should be either a valid identifier or" + ): + mmr.add_request(param="foo", alias=1.4) + + mmr.add_request(param="foo", alias=None) + assert mmr.requests == {"foo": RequestType.ERROR_IF_PASSED} + mmr.add_request(param="foo", alias=False) + assert mmr.requests == {"foo": RequestType.UNREQUESTED} + mmr.add_request(param="foo", alias=True) + assert mmr.requests == {"foo": RequestType.REQUESTED} + mmr.add_request(param="foo", alias="foo") + assert mmr.requests == {"foo": RequestType.REQUESTED} + mmr.add_request(param="foo", alias="bar") + assert mmr.requests == {"foo": "bar"} + assert mmr._get_param_names(return_alias=False) == {"foo"} + assert mmr._get_param_names(return_alias=True) == {"bar"} + + +def test_get_routing_for_object(): + class Consumer(BaseEstimator): + __metadata_request__fit = {"prop": RequestType.ERROR_IF_PASSED} + + assert_request_is_empty(get_routing_for_object(None)) + assert_request_is_empty(get_routing_for_object(object())) + + mr = MetadataRequest(owner="test") + mr.fit.add_request(param="foo", alias="bar") + mr_factory = 
get_routing_for_object(mr) + assert_request_is_empty(mr_factory, exclude="fit") + assert mr_factory.fit.requests == {"foo": "bar"} + + mr = get_routing_for_object(Consumer()) + assert_request_is_empty(mr, exclude="fit") + assert mr.fit.requests == {"prop": RequestType.ERROR_IF_PASSED} + + +def test_metaestimator_warnings(): + class WeightedMetaRegressorWarn(WeightedMetaRegressor): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + with pytest.warns( + UserWarning, match="Support for .* has recently been added to this class" + ): + WeightedMetaRegressorWarn( + estimator=LinearRegression().set_fit_request(sample_weight=False) + ).fit(X, y, sample_weight=my_weights) + + +def test_estimator_warnings(): + class RegressorMetadataWarn(RegressorMetadata): + __metadata_request__fit = {"sample_weight": RequestType.WARN} + + with pytest.warns( + UserWarning, match="Support for .* has recently been added to this class" + ): + MetaRegressor(estimator=RegressorMetadataWarn()).fit( + X, y, sample_weight=my_weights + ) + + +@pytest.mark.parametrize( + "obj, string", + [ + ( + MethodMetadataRequest(owner="test", method="fit").add_request( + param="foo", alias="bar" + ), + "{'foo': 'bar'}", + ), + ( + MetadataRequest(owner="test"), + "{}", + ), + (MethodMapping.from_str("score"), "[{'callee': 'score', 'caller': 'score'}]"), + ( + MetadataRouter(owner="test").add( + method_mapping="predict", estimator=RegressorMetadata() + ), + "{'estimator': {'mapping': [{'callee': 'predict', 'caller': 'predict'}]," + " 'router': {'fit': {'sample_weight': }," + " 'score': {'sample_weight': }}}}", + ), + ], +) +def test_string_representations(obj, string): + assert str(obj) == string + + +@pytest.mark.parametrize( + "obj, method, inputs, err_cls, err_msg", + [ + ( + MethodMapping(), + "add", + {"callee": "invalid", "caller": "fit"}, + ValueError, + "Given callee", + ), + ( + MethodMapping(), + "add", + {"callee": "fit", "caller": "invalid"}, + ValueError, + "Given caller", + ), + ( + 
MethodMapping, + "from_str", + {"route": "invalid"}, + ValueError, + "route should be 'one-to-one' or a single method!", + ), + ( + MetadataRouter(owner="test"), + "add_self", + {"obj": MetadataRouter(owner="test")}, + ValueError, + "Given `obj` is neither a `MetadataRequest` nor does it implement", + ), + ( + ClassifierFitMetadata(), + "set_fit_request", + {"invalid": True}, + TypeError, + "Unexpected args", + ), + ], +) +def test_validations(obj, method, inputs, err_cls, err_msg): + with pytest.raises(err_cls, match=err_msg): + getattr(obj, method)(**inputs) + + +def test_methodmapping(): + mm = ( + MethodMapping() + .add(caller="fit", callee="transform") + .add(caller="fit", callee="fit") + ) + + mm_list = list(mm) + assert mm_list[0] == ("transform", "fit") + assert mm_list[1] == ("fit", "fit") + + mm = MethodMapping.from_str("one-to-one") + assert ( + str(mm) + == "[{'callee': 'fit', 'caller': 'fit'}, {'callee': 'partial_fit', 'caller':" + " 'partial_fit'}, {'callee': 'predict', 'caller': 'predict'}, {'callee':" + " 'score', 'caller': 'score'}, {'callee': 'split', 'caller': 'split'}," + " {'callee': 'transform', 'caller': 'transform'}, {'callee':" + " 'inverse_transform', 'caller': 'inverse_transform'}]" + ) + + mm = MethodMapping.from_str("score") + assert repr(mm) == "[{'callee': 'score', 'caller': 'score'}]" + + +def test_metadatarouter_add_self(): + # adding a MetadataRequest as `self` adds a copy + request = MetadataRequest(owner="nested") + request.fit.add_request(param="param", alias=True) + router = MetadataRouter(owner="test").add_self(request) + assert str(router._self) == str(request) + # should be a copy, not the same object + assert router._self is not request + + # one can add an estimator as self + est = RegressorMetadata().set_fit_request(sample_weight="my_weights") + router = MetadataRouter(owner="test").add_self(obj=est) + assert str(router._self) == str(est.get_metadata_routing()) + assert router._self is not est.get_metadata_routing() + + # 
adding a consumer+router as self should only add the consumer part + est = WeightedMetaRegressor( + estimator=RegressorMetadata().set_fit_request(sample_weight="nested_weights") + ) + router = MetadataRouter(owner="test").add_self(obj=est) + # _get_metadata_request() returns the consumer part of the requests + assert str(router._self) == str(est._get_metadata_request()) + # get_metadata_routing() returns the complete request set, consumer and + # router included. + assert str(router._self) != str(est.get_metadata_routing()) + # it should be a copy, not the same object + assert router._self is not est._get_metadata_request() + + +def test_metadata_routing_add(): + # adding one with a string `method_mapping` + router = MetadataRouter(owner="test").add( + method_mapping="fit", + est=RegressorMetadata().set_fit_request(sample_weight="weights"), + ) + assert ( + str(router) + == "{'est': {'mapping': [{'callee': 'fit', 'caller': 'fit'}], 'router': {'fit':" + " {'sample_weight': 'weights'}, 'score': {'sample_weight':" + " }}}}" + ) + + # adding one with an instance of MethodMapping + router = MetadataRouter(owner="test").add( + method_mapping=MethodMapping().add(callee="score", caller="fit"), + est=RegressorMetadata().set_score_request(sample_weight=True), + ) + assert ( + str(router) + == "{'est': {'mapping': [{'callee': 'score', 'caller': 'fit'}], 'router':" + " {'fit': {'sample_weight': }, 'score':" + " {'sample_weight': }}}}" + ) + + +def test_metadata_routing_get_param_names(): + router = ( + MetadataRouter(owner="test") + .add_self( + WeightedMetaRegressor(estimator=RegressorMetadata()).set_fit_request( + sample_weight="self_weights" + ) + ) + .add( + method_mapping="fit", + trs=TransformerMetadata().set_fit_request( + sample_weight="transform_weights" + ), + ) + ) + + assert ( + str(router) + == "{'$self': {'fit': {'sample_weight': 'self_weights'}, 'score':" + " {'sample_weight': }}, 'trs':" + " {'mapping': [{'callee': 'fit', 'caller': 'fit'}], 'router': {'fit':" + 
" {'brand': , 'sample_weight':" + " 'transform_weights'}, 'transform': {'sample_weight':" + " }}}}" + ) + + assert router._get_param_names( + method="fit", return_alias=True, ignore_self=False + ) == {"transform_weights", "brand", "self_weights"} + # return_alias=False will return original names for "self" + assert router._get_param_names( + method="fit", return_alias=False, ignore_self=False + ) == {"sample_weight", "brand", "transform_weights"} + # ignoring self would remove "sample_weight" + assert router._get_param_names( + method="fit", return_alias=False, ignore_self=True + ) == {"brand", "transform_weights"} + # return_alias is ignored when ignore_self=True + assert router._get_param_names( + method="fit", return_alias=True, ignore_self=True + ) == router._get_param_names(method="fit", return_alias=False, ignore_self=True) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index aa056e92b3d12..ed2cccd3d903d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -21,6 +21,8 @@ import numpy as np from scipy.sparse import issparse +from . import metadata_routing + from .murmurhash import murmurhash3_32 from .class_weight import compute_class_weight, compute_sample_weight from . import _joblib @@ -77,6 +79,7 @@ "DataConversionWarning", "estimator_html_repr", "Bunch", + "metadata_routing", ] IS_PYPY = platform.python_implementation() == "PyPy" diff --git a/sklearn/utils/_metadata_requests.py b/sklearn/utils/_metadata_requests.py new file mode 100644 index 0000000000000..1e28ed6c733cc --- /dev/null +++ b/sklearn/utils/_metadata_requests.py @@ -0,0 +1,1182 @@ +""" +Metadata Routing Utility +""" + +# Author: Adrin Jalali +# License: BSD 3 clause + +import inspect +from copy import deepcopy +from enum import Enum +from warnings import warn +from collections import namedtuple +from typing import Union, Optional +from ._bunch import Bunch + +# This namedtuple is used to store a (mapping, routing) pair. 
Mapping is a +# MethodMapping object, and routing is the output of `get_metadata_routing`. +# MetadataRouter stores a collection of these namedtuples. +RouterMappingPair = namedtuple("RouterMappingPair", ["mapping", "router"]) + +# A namedtuple storing a single method route. A collection of these namedtuples +# is stored in a MetadataRouter. +MethodPair = namedtuple("MethodPair", ["callee", "caller"]) + + +class RequestType(Enum): + """A metadata is requested either with a string alias or this enum. + + .. versionadded:: 1.2 + """ + + # Metadata is not requested. It will not be routed to the object having the + # request value as UNREQUESTED. + UNREQUESTED = False + # Metadata is requested, and will be routed to the requesting object. There + # will be no error if the metadata is not provided. + REQUESTED = True + # Default metadata request configuration. It should not be passed, and if + # present, an error is raised for the user to explicitly set the request + # value. + ERROR_IF_PASSED = None + # this is used in `__metadata_request__*` attributes to indicate that a + # metadata is not present even though it may be present in the + # corresponding method's signature. + UNUSED = "$UNUSED$" + # this is used whenever a default value is changed, and therefore the user + # should explicitly set the value, otherwise a warning is shown. An example + # is when a meta-estimator is only a router, but then becomes also a + # consumer in a new release. + WARN = "$WARN$" + + @classmethod + def is_alias(cls, item): + """Check if an item is a valid alias. + + Parameters + ---------- + item : object + The given item to be checked if it can be an alias. + + Returns + ------- + result : bool + Whether the given item is a valid alias. 
+ """ + try: + cls(item) + except ValueError: + # item is only an alias if it's a valid identifier + return isinstance(item, str) and item.isidentifier() + else: + return False + + @classmethod + def is_valid(cls, item): + """Check if an item is a valid RequestType (and not an alias). + + Parameters + ---------- + item : object + The given item to be checked. + + Returns + ------- + result : bool + Whether the given item is valid. + """ + try: + cls(item) + return True + except ValueError: + return False + + +# this is the default used in `set_{method}_request` methods to indicate no change +# requested by the user. +UNCHANGED = "$UNCHANGED$" + +# Only the following methods are supported in the routing mechanism. Adding new +# methods at the moment involves monkeypatching this list. +METHODS = [ + "fit", + "partial_fit", + "predict", + "score", + "split", + "transform", + "inverse_transform", +] + + +# These strings are used to dynamically generate the docstrings for +# set_{method}_request methods. +REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. + + Please see :ref:`User Guide ` on how the routing + mechanism works. + + .. versionadded:: 1.2 + + Parameters + ---------- +""" +REQUESTER_DOC_PARAM = """ {metadata} : RequestType, str, True, False, or None, \ + default=UNCHANGED + Whether ``{metadata}`` should be passed to ``{method}`` by + meta-estimators or not, and if yes, should it have an alias. + + - True or RequestType.REQUESTED: ``{metadata}`` is requested, and \ +passed to ``{method}`` if provided. The request is ignored if \ +``{metadata}`` is not provided. + + - False or RequestType.UNREQUESTED: ``{metadata}`` is not requested \ +and the meta-estimator will not pass it to ``{method}``. + + - None or RequestType.ERROR_IF_PASSED: ``{metadata}`` is not \ +requested, and the meta-estimator will raise an error if the user provides \ +``{metadata}``. 
+ + - str: ``{metadata}`` should be passed to the meta-estimator with \ +this given alias instead of the original name. + + The default (UNCHANGED) retains the existing request. This allows + you to change the request for some parameters and not others. + +""" +REQUESTER_DOC_RETURN = """ Returns + ------- + self : object + The updated object. +""" + + +class MethodMetadataRequest: + """A prescription of how metadata is to be passed to a single method. + + Refer to :class:`MetadataRequest` for how this class is used. + + .. versionadded:: 1.2 + + Parameters + ---------- + owner : str + A display name for the object owning these requests. + + method : str + The name of the method to which these requests belong. + """ + + def __init__(self, owner, method): + self._requests = dict() + self.owner = owner + self.method = method + + @property + def requests(self): + """Dictionary of the form: ``{key: alias}``.""" + return self._requests + + def add_request( + self, + *, + param, + alias, + ): + """Add request info for a metadata. + + Parameters + ---------- + param : str + The property for which a request is set. + + alias : str, RequestType, or {True, False, None} + Specifies which metadata should be routed to `param` + + - str: the name (or alias) of metadata given to a meta-estimator that + should be routed to this parameter. + + - True or RequestType.REQUESTED: requested + + - False or RequestType.UNREQUESTED: not requested + + - None or RequestType.ERROR_IF_PASSED: error if passed + """ + if RequestType.is_valid(alias): + alias = RequestType(alias) + elif not RequestType.is_alias(alias): + raise ValueError( + "alias should be either a valid identifier or one of " + "{None, True, False}, or a RequestType." 
+ ) + + if alias == param: + alias = RequestType.REQUESTED + + if alias == RequestType.UNUSED and param in self._requests: + del self._requests[param] + else: + self._requests[param] = alias + + return self + + def _get_param_names(self, return_alias): + """Get names of all metadata that can be consumed or routed by this method. + + This method returns the names of all metadata except those marked + UNREQUESTED. + + Parameters + ---------- + return_alias : bool + Controls whether original or aliased names should be returned. If + ``False``, aliases are ignored and original names are returned. + + Returns + ------- + names : set of str + A set of strings with the names of all parameters. + """ + return set( + alias if return_alias and not RequestType.is_valid(alias) else prop + for prop, alias in self._requests.items() + if not RequestType.is_valid(alias) + or RequestType(alias) != RequestType.UNREQUESTED + ) + + def _check_warnings(self, *, params): + """Check whether metadata is passed which is marked as WARN. + + If any metadata is passed which is marked as WARN, a warning is raised. + + Parameters + ---------- + params : dict + The metadata passed to a method. + """ + params = {} if params is None else params + warn_params = { + prop + for prop, alias in self._requests.items() + if alias == RequestType.WARN and prop in params + } + for param in warn_params: + warn( + f"Support for {param} has recently been added to this class. " + "To maintain backward compatibility, it is ignored now. " + "You can set the request value to RequestType.UNREQUESTED " + "to silence this warning, or to RequestType.REQUESTED to " + "consume and use the metadata." + ) + + def _route_params(self, params=None): + """Prepare the given parameters to be passed to the method. + + The output of this method can be used directly as the input to the + corresponding method as extra props. + + Parameters + ---------- + params : dict + A dictionary of provided metadata. 
+ + Returns + ------- + params : Bunch + A :class:`~utils.Bunch` of {prop: value} which can be given to the + corresponding method. + """ + self._check_warnings(params=params) + params = {} if params is None else params + args = {arg: value for arg, value in params.items() if value is not None} + res = Bunch() + for prop, alias in self._requests.items(): + if RequestType.is_valid(alias): + alias = RequestType(alias) + + if alias == RequestType.UNREQUESTED or alias == RequestType.WARN: + continue + elif alias == RequestType.REQUESTED and prop in args: + res[prop] = args[prop] + elif alias == RequestType.ERROR_IF_PASSED and prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not for {self.owner}.{self.method}" + ) + elif alias in args: + res[prop] = args[alias] + return res + + def _serialize(self): + """Serialize the object. + + Returns + ------- + obj : dict + A serialized version of the instance in the form of a dictionary. + """ + return { + prop: RequestType(alias) if RequestType.is_valid(alias) else alias + for prop, alias in self._requests.items() + } + + def __repr__(self): + return str(self._serialize()) + + def __str__(self): + return str(repr(self)) + + +class MetadataRequest: + """Contains the metadata request info of a consumer. + + Instances of :class:`MethodMetadataRequest` are used in this class for each + available method under `metadatarequest.{method}`. + + Consumer-only classes such as simple estimators return a serialized + version of this class as the output of `get_metadata_routing()`. + + .. versionadded:: 1.2 + + Parameters + ---------- + owner : str + The name of the object to which these requests belong. + """ + + # this is here for us to use this attribute's value instead of doing + # `isinstance` in our checks, so that we avoid issues when people vendor + # this file instead of using it directly from scikit-learn. 
class MetadataRequest:
    """Contains the metadata request info of a consumer.

    One :class:`MethodMetadataRequest` is stored as an attribute of this
    object for each name in ``METHODS``, i.e. under
    `metadatarequest.{method}`.

    Consumer-only classes such as simple estimators return a serialized
    version of this class as the output of `get_metadata_routing()`.

    .. versionadded:: 1.2

    Parameters
    ----------
    owner : str
        The name of the object to which these requests belong.
    """

    # Checks elsewhere compare this tag instead of using `isinstance`, so
    # that they keep working when this file is vendored rather than imported
    # from scikit-learn.
    _type = "metadata_request"

    def __init__(self, owner):
        for name in METHODS:
            per_method = MethodMetadataRequest(owner=owner, method=name)
            setattr(self, name, per_method)

    def _get_param_names(self, method, return_alias, ignore_self=None):
        """Get the metadata names which the given method consumes or routes.

        The work is delegated to the request object stored for ``method``.

        Parameters
        ----------
        method : str
            The name of the method for which metadata names are requested.

        return_alias : bool
            Controls whether original or aliased names should be returned. If
            ``False``, aliases are ignored and original names are returned.

        ignore_self : bool
            Ignored. Present for API compatibility.

        Returns
        -------
        names : set of str
            A set of strings with the names of all parameters.
        """
        method_request = getattr(self, method)
        return method_request._get_param_names(return_alias=return_alias)

    def _route_params(self, *, method, params):
        """Prepare the given parameters to be passed to the method.

        The output can be passed to the corresponding method as extra keyword
        arguments carrying the metadata.

        Parameters
        ----------
        method : str
            The name of the method for which the parameters are requested and
            routed.

        params : dict
            A dictionary of provided metadata.

        Returns
        -------
        params : Bunch
            A :class:`~utils.Bunch` of {prop: value} which can be given to the
            corresponding method.
        """
        method_request = getattr(self, method)
        return method_request._route_params(params=params)

    def _check_warnings(self, *, method, params):
        """Warn if any passed metadata is marked as WARN for ``method``.

        Parameters
        ----------
        method : str
            The name of the method for which the warnings should be checked.

        params : dict
            The metadata passed to a method.
        """
        method_request = getattr(self, method)
        method_request._check_warnings(params=params)

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            A dictionary mapping each method which has at least one request
            to the serialized form of that method's requests.
        """
        per_method = ((name, getattr(self, name)) for name in METHODS)
        return {
            name: request._serialize()
            for name, request in per_method
            if len(request.requests)
        }

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return str(repr(self))
class MethodMapping:
    """Stores the mapping between callee and caller methods for a router.

    This class is primarily used in a ``get_metadata_routing()`` of a router
    object when defining the mapping between a sub-object (a sub-estimator or a
    scorer) to the router's methods. It stores a collection of
    ``MethodPair(callee, caller)`` named tuples, and iterating through an
    instance yields them in insertion order.

    .. versionadded:: 1.2
    """

    def __init__(self):
        self._routes = []

    def __iter__(self):
        return iter(self._routes)

    def add(self, *, callee, caller):
        """Add a method mapping.

        Parameters
        ----------
        callee : str
            Child object's method name. This method is called in ``caller``.

        caller : str
            Parent estimator's method name in which the ``callee`` is called.

        Returns
        -------
        self : MethodMapping
            Returns self.

        Raises
        ------
        ValueError
            If ``callee`` or ``caller`` is not one of ``METHODS``.
        """
        # validate callee first, then caller
        for role, value in (("callee", callee), ("caller", caller)):
            if value not in METHODS:
                raise ValueError(
                    f"Given {role}:{value} is not a valid method. Valid methods are:"
                    f" {METHODS}"
                )
        pair = MethodPair(callee=callee, caller=caller)
        self._routes.append(pair)
        return self

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : list
            A list with one ``{"callee": ..., "caller": ...}`` dictionary per
            stored route.
        """
        return [
            {"callee": pair.callee, "caller": pair.caller} for pair in self._routes
        ]

    @classmethod
    def from_str(cls, route):
        """Construct an instance from a string.

        Parameters
        ----------
        route : str
            A string representing the mapping, it can be:

              - `"one-to-one"`: a one to one mapping for all methods.
              - `"method"`: the name of a single method.

        Returns
        -------
        obj : MethodMapping
            A :class:`~utils.metadata_requests.MethodMapping` instance
            constructed from the given string.

        Raises
        ------
        ValueError
            If ``route`` is neither ``"one-to-one"`` nor a name in ``METHODS``.
        """
        routing = cls()
        if route == "one-to-one":
            mapped = METHODS
        elif route in METHODS:
            mapped = [route]
        else:
            raise ValueError("route should be 'one-to-one' or a single method!")

        for name in mapped:
            routing.add(callee=name, caller=name)
        return routing

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return str(repr(self))
class MetadataRouter:
    """Stores and handles metadata routing for a router object.

    This class is used by router objects to store and handle metadata routing.
    Routing information is stored as a dictionary of the form ``{"object_name":
    RouterMappingPair(mapping, router)}``, where ``mapping`` is an instance of
    :class:`~utils.metadata_requests.MethodMapping` and ``router`` is either a
    :class:`~utils.metadata_requests.MetadataRequest` or a
    :class:`~utils.metadata_requests.MetadataRouter` instance.

    .. versionadded:: 1.2

    Parameters
    ----------
    owner : str
        The name of the object to which these requests belong.
    """

    # Checks elsewhere compare this tag instead of using `isinstance`, so
    # that they keep working when this file is vendored instead of being
    # used directly from scikit-learn.
    _type = "metadata_router"

    def __init__(self, owner):
        self._route_mappings = dict()
        # `_self` is only set (through `add_self()`) when the router is also
        # a consumer; it is handled separately from the child objects kept in
        # `_route_mappings`.
        self._self = None
        self.owner = owner

    def add_self(self, obj):
        """Add `self` (as a consumer) to the routing.

        This method is used if the router is also a consumer, and hence the
        router itself needs to be included in the routing. The passed object
        can be an estimator or a
        :class:`~utils.metadata_requests.MetadataRequest`. A copy of the
        request is stored.

        A router should add itself using this method instead of `add` since it
        should be treated differently than the other objects to which metadata
        is routed by the router.

        Parameters
        ----------
        obj : object
            This is typically the router instance, i.e. `self` in a
            ``get_metadata_routing()`` implementation. It can also be a
            ``MetadataRequest`` instance.

        Returns
        -------
        self : MetadataRouter
            Returns `self`.

        Raises
        ------
        ValueError
            If ``obj`` is neither a ``MetadataRequest`` nor provides
            ``_get_metadata_request``.
        """
        if getattr(obj, "_type", None) == "metadata_request":
            request = obj
        elif hasattr(obj, "_get_metadata_request"):
            request = obj._get_metadata_request()
        else:
            raise ValueError(
                "Given `obj` is neither a `MetadataRequest` nor does it implement the"
                " required API. Inheriting from `BaseEstimator` implements the required"
                " API."
            )
        self._self = deepcopy(request)
        return self

    def add(self, *, method_mapping, **objs):
        """Add named objects with their corresponding method mapping.

        Parameters
        ----------
        method_mapping : MethodMapping or str
            The mapping between the child and the parent's methods. If str, the
            output of :func:`~utils.metadata_requests.MethodMapping.from_str`
            is used. Otherwise a copy of the given mapping is stored.

        **objs : dict
            A dictionary of objects from which metadata is extracted by calling
            :func:`~utils.metadata_requests.get_routing_for_object` on them.

        Returns
        -------
        self : MetadataRouter
            Returns `self`.
        """
        if isinstance(method_mapping, str):
            mapping = MethodMapping.from_str(method_mapping)
        else:
            mapping = deepcopy(method_mapping)

        self._route_mappings.update(
            (
                name,
                RouterMappingPair(mapping=mapping, router=get_routing_for_object(obj)),
            )
            for name, obj in objs.items()
        )
        return self

    def _get_param_names(self, *, method, return_alias, ignore_self):
        """Get the metadata names consumed or routed by the given method.

        The names are the union of the names reported by the stored `self`
        request (unless ignored) and by every child whose mapping has
        ``method`` as the caller.

        Parameters
        ----------
        method : str
            The name of the method for which metadata names are requested.

        return_alias : bool
            Controls whether original or aliased names should be returned,
            which only applies to the stored `self`. If no `self` routing
            object is stored, this parameter has no effect. Children are
            always queried for their aliased names.

        ignore_self : bool
            If `self._self` should be ignored. This is used in `_route_params`.
            If ``True``, ``return_alias`` has no effect.

        Returns
        -------
        names : set of str
            A set of strings with the names of all parameters.
        """
        names = set()
        if self._self and not ignore_self:
            names.update(
                self._self._get_param_names(method=method, return_alias=return_alias)
            )

        for pair in self._route_mappings.values():
            for callee, caller in pair.mapping:
                if caller != method:
                    continue
                names.update(
                    pair.router._get_param_names(
                        method=callee, return_alias=True, ignore_self=False
                    )
                )
        return names

    def _route_params(self, *, params, method):
        """Prepare the given parameters to be passed to the method.

        This is used when a router is used as a child object of another router.
        The parent router then passes all parameters understood by the child
        object to it and delegates their validation to the child.

        The output of this method can be used directly as the input to the
        corresponding method as extra props.

        Parameters
        ----------
        method : str
            The name of the method for which the parameters are requested and
            routed.

        params : dict
            A dictionary of provided metadata.

        Returns
        -------
        params : Bunch
            A :class:`~utils.Bunch` of {prop: value} which can be given to the
            corresponding method.

        Raises
        ------
        ValueError
            If the same key is routed both to this object as a consumer and
            to its children, with different objects as values.
        """
        routed = Bunch()
        if self._self:
            routed.update(self._self._route_params(params=params, method=method))

        child_names = self._get_param_names(
            method=method, return_alias=True, ignore_self=True
        )
        for_children = {
            key: value for key, value in params.items() if key in child_names
        }
        for key in set(routed).intersection(for_children):
            # sharing a key is only a problem when the two values are not the
            # very same object
            if for_children[key] is not routed[key]:
                raise ValueError(
                    f"In {self.owner}, there is a conflict on {key} between what is"
                    " requested for this estimator and what is requested by its"
                    " children. You can resolve this conflict by using an alias for"
                    " the child estimator(s) requested metadata."
                )

        routed.update(for_children)
        return routed

    def route_params(self, *, caller, params):
        """Return the input parameters requested by child objects.

        The output of this method is a bunch, which includes the inputs for all
        methods of each child object that are used in the router's `caller`
        method.

        If the router is also a consumer, it also checks for warnings of
        `self`'s/consumer's requested metadata.

        Parameters
        ----------
        caller : str
            The name of the method for which the parameters are requested and
            routed. If called inside the :term:`fit` method of a router, it
            would be `"fit"`.

        params : dict
            A dictionary of provided metadata.

        Returns
        -------
        params : Bunch
            A :class:`~utils.Bunch` of the form
            ``{"object_name": {"method_name": {prop: value}}}`` which can be
            used to pass the required metadata to corresponding methods or
            corresponding child objects.
        """
        if self._self:
            self._self._check_warnings(params=params, method=caller)

        routed = Bunch()
        for name, pair in self._route_mappings.items():
            # every child gets an entry, even when nothing is routed to it
            routed[name] = Bunch()
            for callee, mapped_caller in pair.mapping:
                if mapped_caller != caller:
                    continue
                routed[name][callee] = pair.router._route_params(
                    params=params, method=callee
                )
        return routed

    def validate_metadata(self, *, method, params):
        """Validate given metadata for a method.

        This raises a ``TypeError`` if some of the passed metadata are not
        understood by this object or its child objects.

        Parameters
        ----------
        method : str
            The name of the method for which the parameters are requested and
            routed. If called inside the :term:`fit` method of a router, it
            would be `"fit"`.

        params : dict
            A dictionary of provided metadata.
        """
        known = self._get_param_names(
            method=method, return_alias=False, ignore_self=False
        )
        self_params = set()
        if self._self:
            self_params = self._self._get_param_names(
                method=method, return_alias=False
            )

        extra_keys = set(params) - known - self_params
        if extra_keys:
            raise TypeError(
                f"{method} got unexpected argument(s) {extra_keys}, which are "
                "not requested metadata in any object."
            )

    def _serialize(self):
        """Serialize the object.

        Returns
        -------
        obj : dict
            A dictionary holding the serialized `self` request under
            ``"$self"`` (if any), and for each child a dictionary with its
            serialized ``"mapping"`` and ``"router"``.
        """
        serialized = dict()
        if self._self:
            serialized["$self"] = self._self._serialize()
        for name, pair in self._route_mappings.items():
            serialized[name] = {
                "mapping": pair.mapping._serialize(),
                "router": pair.router._serialize(),
            }
        return serialized

    def __iter__(self):
        if self._self:
            self_pair = RouterMappingPair(
                mapping=MethodMapping.from_str("one-to-one"), router=self._self
            )
            yield "$self", self_pair
        for item in self._route_mappings.items():
            yield item

    def __repr__(self):
        serialized = self._serialize()
        return str(serialized)

    def __str__(self):
        return str(repr(self))
def get_routing_for_object(obj=None):
    """Get a ``Metadata{Router, Request}`` instance from the given object.

    This function returns a
    :class:`~utils.metadata_requests.MetadataRouter` or a
    :class:`~utils.metadata_requests.MetadataRequest` from the given input.

    The returned value is always a copy or a newly constructed instance, so
    that changing the output of this function does not change the original
    object.

    .. versionadded:: 1.2

    Parameters
    ----------
    obj : object
        - If the object provides a `get_metadata_routing` method, return a copy
          of the output of that method. This is checked first.
        - If the object is already a
          :class:`~utils.metadata_requests.MetadataRequest` or a
          :class:`~utils.metadata_requests.MetadataRouter`, return a copy
          of that.
        - Returns an empty :class:`~utils.metadata_requests.MetadataRequest`
          otherwise, including when ``obj`` is ``None``.

    Returns
    -------
    obj : MetadataRequest or MetadataRouter
        A ``MetadataRequest`` or a ``MetadataRouter`` taken or created from
        the given object.
    """
    if obj is not None:
        # `hasattr` is used rather than try/except, since an AttributeError
        # could be raised for unrelated reasons inside the called method.
        if hasattr(obj, "get_metadata_routing"):
            routing = obj.get_metadata_routing()
            return deepcopy(routing)

        kind = getattr(obj, "_type", None)
        if kind in ["metadata_request", "metadata_router"]:
            return deepcopy(obj)

    return MetadataRequest(owner=None)
class RequestMethod:
    """
    A descriptor for request methods.

    Accessing the descriptor builds and returns a ``set_{name}_request``
    function bound to the instance it is accessed from.

    .. versionadded:: 1.2

    Parameters
    ----------
    name : str
        The name of the method for which the request function should be
        created, e.g. ``"fit"`` would create a ``set_fit_request`` function.

    keys : list of str
        A list of strings which are accepted parameters by the created
        function, e.g. ``["sample_weight"]`` if the corresponding method
        accepts it as a metadata.

    Notes
    -----
    This class is a descriptor [1]_ and uses PEP-362 to set the signature of
    the returned function [2]_.

    References
    ----------
    .. [1] https://docs.python.org/3/howto/descriptor.html

    .. [2] https://www.python.org/dev/peps/pep-0362/
    """

    def __init__(self, name, keys):
        self.name = name
        self.keys = keys

    def __get__(self, instance, owner):
        # The returned function only accepts the expected keyword arguments.
        def func(**kw):
            """Updates the request for provided parameters

            This docstring is overwritten below.
            See REQUESTER_DOC for expected functionality
            """
            unexpected = set(kw) - set(self.keys)
            if unexpected:
                raise TypeError(
                    f"Unexpected args: {unexpected}. Accepted arguments"
                    f" are: {set(self.keys)}"
                )

            requests = instance._get_metadata_request()
            method_request = getattr(requests, self.name)

            for prop, alias in kw.items():
                # UNCHANGED is the signature default: leave that entry alone
                if alias is UNCHANGED:
                    continue
                method_request.add_request(param=prop, alias=alias)
            instance._metadata_request = requests

            return instance

        # Dress the function up so that it looks like a regular method to the
        # end user: proper name, explicit signature and a generated docstring.
        func.__name__ = f"set_{self.name}_request"

        self_param = inspect.Parameter(
            name="self",
            kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
            annotation=owner,
        )
        metadata_params = [
            inspect.Parameter(
                key,
                inspect.Parameter.KEYWORD_ONLY,
                default=UNCHANGED,
                annotation=Optional[Union[RequestType, str]],
            )
            for key in self.keys
        ]
        func.__signature__ = inspect.Signature(
            [self_param] + metadata_params,
            return_annotation=owner,
        )

        doc = REQUESTER_DOC.format(method=self.name)
        doc += "".join(
            REQUESTER_DOC_PARAM.format(metadata=key, method=self.name)
            for key in self.keys
        )
        doc += REQUESTER_DOC_RETURN
        func.__doc__ = doc
        return func
class _MetadataRequester:
    """Mixin class for adding metadata request functionality.

    .. versionadded:: 1.2
    """

    def __init_subclass__(cls, **kwargs):
        """Set the ``set_{method}_request`` methods.

        This uses PEP-487 [1]_ to set the ``set_{method}_request`` methods. It
        looks for the information available in the set default values which are
        set using ``__metadata_request__*`` class attributes, or inferred
        from method signatures.

        The ``__metadata_request__*`` class attributes are used when a method
        does not explicitly accept a metadata through its arguments or if the
        developer would like to specify a request value for those metadata
        which are different from the default ``RequestType.ERROR_IF_PASSED``.

        References
        ----------
        .. [1] https://www.python.org/dev/peps/pep-0487
        """
        try:
            requests = cls._get_default_requests()
        except Exception:
            # Deliberately best-effort: any issue with the defaults (e.g. bad
            # default values) is ignored here, since it will be raised when
            # ``get_metadata_routing`` is called.
            super().__init_subclass__(**kwargs)
            return

        for method in METHODS:
            mmr = getattr(requests, method)
            # only methods with at least one metadata get a
            # ``set_{method}_request`` method
            if not len(mmr.requests):
                continue
            setattr(
                cls,
                f"set_{method}_request",
                RequestMethod(method, sorted(mmr.requests.keys())),
            )
        super().__init_subclass__(**kwargs)

    @classmethod
    def _build_request_for_signature(cls, method):
        """Build the `MethodMetadataRequest` for a method using its signature.

        This method takes all arguments from the method signature and uses
        ``RequestType.ERROR_IF_PASSED`` as their default request value, except
        ``X``, ``y``, ``Y``, ``*args``, and ``**kwargs``.

        Parameters
        ----------
        method : str
            The name of the method.

        Returns
        -------
        method_request : MethodMetadataRequest
            The prepared request using the method's signature. It is empty if
            the class has no plain function under the given name.
        """
        mmr = MethodMetadataRequest(owner=cls.__name__, method=method)
        # Here we use `isfunction` instead of `ismethod` because calling `getattr`
        # on a class instead of an instance returns an unbound function.
        if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)):
            return mmr
        # ignore the first parameter of the method, which is usually "self"
        params = list(inspect.signature(getattr(cls, method)).parameters.items())[1:]
        for pname, param in params:
            if pname in {"X", "y", "Y"}:
                continue
            if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}:
                continue
            mmr.add_request(
                param=pname,
                alias=RequestType.ERROR_IF_PASSED,
            )
        return mmr

    @classmethod
    def _get_default_requests(cls):
        """Collect default request values.

        This method combines the information present in
        ``__metadata_request__*`` class attributes, as well as determining
        request keys from method signatures.

        Returns
        -------
        requests : MetadataRequest
            The default requests of this class.
        """
        requests = MetadataRequest(owner=cls.__name__)
        for method in METHODS:
            setattr(requests, method, cls._build_request_for_signature(method=method))

        # Then overwrite those defaults with the ones provided in
        # __metadata_request__* attributes. Defaults set in
        # __metadata_request__* attributes take precedence over signature
        # sniffing.

        # We need to go through the MRO since this is a class attribute and
        # ``vars`` doesn't report the parent class attributes. The MRO is
        # walked in reverse and the defaults are applied in that very order,
        # so that child classes have precedence over their parents. The
        # attributes must NOT be re-sorted by name: Python mangles them to
        # ``_ClassName__metadata_request__*``, so sorting would order them by
        # class name and could let a parent override its child.
        substr = "__metadata_request__"
        for base_class in reversed(inspect.getmro(cls)):
            for attr, value in vars(base_class).items():
                # we don't check for attr.startswith() since python prefixes
                # attrs starting with __ with the `_ClassName`.
                if substr not in attr:
                    continue
                method = attr[attr.index(substr) + len(substr) :]
                for prop, alias in value.items():
                    getattr(requests, method).add_request(param=prop, alias=alias)
        return requests

    def _get_metadata_request(self):
        """Get requested data properties.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        request : MetadataRequest
            A :class:`~.utils.metadata_requests.MetadataRequest` instance:
            a copy of the requests stored on the instance if there are any,
            the class defaults otherwise.
        """
        if hasattr(self, "_metadata_request"):
            requests = get_routing_for_object(self._metadata_request)
        else:
            requests = self._get_default_requests()

        return requests

    def get_metadata_routing(self):
        """Get metadata routing of this object.

        Please check :ref:`User Guide <metadata_routing>` on how the routing
        mechanism works.

        Returns
        -------
        routing : MetadataRequest
            A :class:`~utils.metadata_routing.MetadataRequest` encapsulating
            routing information.
        """
        return self._get_metadata_request()
def process_routing(obj, method, other_params, **kwargs):
    """Validate and route input parameters.

    This function is used inside a router's method, e.g. :term:`fit`,
    to validate the metadata and handle the routing.

    Assuming this signature: ``fit(self, X, y, sample_weight=None, **fit_params)``,
    a call to this function would be:
    ``process_routing(self, "fit", fit_params, sample_weight=sample_weight)``.

    .. versionadded:: 1.2

    Parameters
    ----------
    obj : object
        An object implementing ``get_metadata_routing``. Typically a
        meta-estimator.

    method : str
        The name of the router's method in which this function is called.

    other_params : dict
        A dictionary of extra parameters passed to the router's method,
        e.g. ``**fit_params`` passed to a meta-estimator's :term:`fit`.
        ``None`` is treated as an empty dictionary. The given dictionary is
        not modified.

    **kwargs : dict
        Parameters explicitly accepted and included in the router's method
        signature.

    Returns
    -------
    routed_params : Bunch
        A :class:`~utils.Bunch` of the form ``{"object_name": {"method_name":
        {prop: value}}}`` which can be used to pass the required metadata to
        corresponding methods or corresponding child objects. The object names
        are those defined in `obj.get_metadata_routing()`.

    Raises
    ------
    AttributeError
        If ``obj`` does not implement ``get_metadata_routing``.

    TypeError
        If ``method`` is not one of ``METHODS``.
    """
    if not hasattr(obj, "get_metadata_routing"):
        raise AttributeError(
            f"This {repr(obj.__class__.__name__)} has not implemented the routing"
            " method `get_metadata_routing`."
        )
    if method not in METHODS:
        raise TypeError(
            f"Can only route and process input on these methods: {METHODS}, "
            f"while the passed method is: {method}."
        )

    # We take the extra params (**fit_params) which is passed as `other_params`
    # and add the explicitly passed parameters (passed as **kwargs) to it.
    # The merge is done on a copy, so that the dictionary owned by the caller
    # is never mutated as a side effect of routing.
    all_params = dict(other_params) if other_params is not None else dict()
    all_params.update(kwargs)

    request_routing = get_routing_for_object(obj)
    request_routing.validate_metadata(params=all_params, method=method)
    routed_params = request_routing.route_params(params=all_params, caller=method)

    return routed_params
+""" + +# Author: Adrin Jalali +# License: BSD 3 clause + +from ._metadata_requests import RequestType # noqa +from ._metadata_requests import get_routing_for_object # noqa +from ._metadata_requests import MetadataRouter # noqa +from ._metadata_requests import MetadataRequest # noqa +from ._metadata_requests import MethodMapping # noqa +from ._metadata_requests import process_routing # noqa +from ._metadata_requests import _MetadataRequester # noqa diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 3a88b4431fe86..b71cc8423e75e 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -653,6 +653,10 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass + class ConformantEstimatorClassAttribute(BaseEstimator): + # making sure our __metadata_request__* class attributes are okay! + __metadata_request__fit = {"foo": True} + msg = ( "Estimator estimator_name should not set any" " attribute apart from parameters during init." @@ -672,6 +676,17 @@ def __init__(self, you_should_set_this_=None): "estimator_name", NonConformantEstimatorNoParamSet() ) + # a private class attribute is okay! + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute() + ) + # also check if cloning an estimator which has non-default set requests is + # fine. Setting a non-default value via `set_{method}_request` sets the + # private _metadata_request instance attribute which is copied in `clone`. + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute().set_fit_request(foo=True) + ) + def test_check_estimator_pairwise(): # check that check_estimator() works on estimator with _pairwise