|
| 1 | + |
| 2 | +.. _metadata_routing: |
| 3 | + |
| 4 | +.. currentmodule:: sklearn |
| 5 | + |
| 6 | +.. TODO: update doc/conftest.py once document is updated and examples run. |
| 7 | + |
| 8 | +Metadata Routing |
| 9 | +================ |
| 10 | + |
| 11 | +.. note:: |
| 12 | + The Metadata Routing API is experimental, and is not implemented yet for many |
| 13 | + estimators. It may change without the usual deprecation cycle. By default |
| 14 | + this feature is not enabled. You can enable this feature by setting the |
| 15 | + ``enable_metadata_routing`` flag to ``True``: |
| 16 | + |
| 17 | + >>> import sklearn |
| 18 | + >>> sklearn.set_config(enable_metadata_routing=True) |
| 19 | + |
| 20 | +This guide demonstrates how metadata such as ``sample_weight`` can be routed |
| 21 | +and passed along to estimators, scorers, and CV splitters through |
| 22 | +meta-estimators such as :class:`~pipeline.Pipeline` and |
| 23 | +:class:`~model_selection.GridSearchCV`. In order to pass metadata to a method |
| 24 | +such as ``fit`` or ``score``, the object consuming the metadata must *request* |
| 25 | +it. For estimators and splitters, this is done via ``set_*_request`` methods, |
| 26 | +e.g. ``set_fit_request(...)``, and for scorers this is done via the |
| 27 | +``set_score_request`` method. For grouped splitters such as |
| 28 | +:class:`~model_selection.GroupKFold`, a ``groups`` parameter is requested by |
| 29 | +default. This is best demonstrated by the following examples. |
| 30 | + |
| 31 | +If you are developing a scikit-learn compatible estimator or meta-estimator, |
| 32 | +you can check our related developer guide: |
| 33 | +:ref:`sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py`. |
| 34 | + |
| 35 | +.. note:: |
| 36 | + Note that the methods and requirements introduced in this document are only |
| 37 | + relevant if you want to pass metadata (e.g. ``sample_weight``) to a method. |
| 38 | + If you're only passing ``X`` and ``y`` and no other parameter / metadata to |
| 39 | + methods such as ``fit``, ``transform``, etc., then you don't need to set |
| 40 | + anything. |
| 41 | + |
| 42 | +Usage Examples |
| 43 | +************** |
| 44 | +Here we present a few examples to show different common use-cases. The examples |
| 45 | +in this section require the following imports and data:: |
| 46 | + |
| 47 | + >>> import numpy as np |
| 48 | + >>> from sklearn.metrics import make_scorer, accuracy_score |
| 49 | + >>> from sklearn.linear_model import LogisticRegressionCV, LogisticRegression |
| 50 | + >>> from sklearn.model_selection import cross_validate, GridSearchCV, GroupKFold |
| 51 | + >>> from sklearn.feature_selection import SelectKBest |
| 52 | + >>> from sklearn.pipeline import make_pipeline |
| 53 | + >>> n_samples, n_features = 100, 4 |
| 54 | + >>> rng = np.random.RandomState(42) |
| 55 | + >>> X = rng.rand(n_samples, n_features) |
| 56 | + >>> y = rng.randint(0, 2, size=n_samples) |
| 57 | + >>> my_groups = rng.randint(0, 10, size=n_samples) |
| 58 | + >>> my_weights = rng.rand(n_samples) |
| 59 | + >>> my_other_weights = rng.rand(n_samples) |
| 60 | + |
| 61 | +Weighted scoring and fitting |
| 62 | +---------------------------- |
| 63 | + |
| 64 | +Here :class:`~model_selection.GroupKFold` requests ``groups`` by default. However, we |
| 65 | +need to explicitly request weights for our scorer and the internal cross validation of |
| 66 | +:class:`~linear_model.LogisticRegressionCV`. Both of these *consumers* know how to use |
| 67 | +metadata called ``sample_weight``:: |
| 68 | + |
| 69 | + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( |
| 70 | + ... sample_weight=True |
| 71 | + ... ) |
| 72 | + >>> lr = LogisticRegressionCV( |
| 73 | + ... cv=GroupKFold(), scoring=weighted_acc, |
| 74 | + ... ).set_fit_request(sample_weight=True) |
| 75 | + >>> cv_results = cross_validate( |
| 76 | + ... lr, |
| 77 | + ... X, |
| 78 | + ... y, |
| 79 | + ... props={"sample_weight": my_weights, "groups": my_groups}, |
| 80 | + ... cv=GroupKFold(), |
| 81 | + ... scoring=weighted_acc, |
| 82 | + ... ) |
| 83 | + |
| 84 | +Note that in this example, ``my_weights`` is passed to both the scorer and |
| 85 | +:class:`~linear_model.LogisticRegressionCV`. |
| 86 | + |
| 87 | +Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed |
| 88 | +(note the typo), :func:`~model_selection.cross_validate` would raise an error, |
| 89 | +since ``sample_weigh`` was not requested by any of its underlying objects. |
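The check described above can be sketched in plain Python. This is only an illustration of the idea, not scikit-learn's implementation: the helper name ``check_metadata_keys``, the exception type, and the message are all assumptions made for this sketch.

```python
def check_metadata_keys(props, requested_keys):
    """Reject metadata keys that no consumer requested (hypothetical helper)."""
    unexpected = sorted(set(props) - set(requested_keys))
    if unexpected:
        # Exception type and message are assumptions for this sketch.
        raise TypeError(f"Metadata not requested by any consumer: {unexpected}")


# Keys requested by the consumers in the example above.
requested = {"sample_weight", "groups"}

# Correct spelling: the check passes silently.
check_metadata_keys({"sample_weight": [0.2, 0.8], "groups": [0, 1]}, requested)

# Misspelled key: the check fails instead of silently dropping the weights.
try:
    check_metadata_keys({"sample_weigh": [0.2, 0.8], "groups": [0, 1]}, requested)
except TypeError as exc:
    typo_error = str(exc)
```

Failing loudly on an unknown key is what keeps a typo from turning a weighted fit into an unweighted one without the user noticing.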
| 90 | + |
| 91 | +Weighted scoring and unweighted fitting |
| 92 | +--------------------------------------- |
| 93 | + |
| 94 | +When passing metadata such as ``sample_weight`` around, all scikit-learn |
| 95 | +estimators require weights to be either explicitly requested or not requested |
| 96 | +(i.e. ``True`` or ``False``) when used in another router such as a |
| 97 | +:class:`~pipeline.Pipeline` or a ``*GridSearchCV``. To perform an unweighted |
| 98 | +fit, we need to configure :class:`~linear_model.LogisticRegressionCV` to not |
| 99 | +request sample weights, so that :func:`~model_selection.cross_validate` does |
| 100 | +not pass the weights along:: |
| 101 | + |
| 102 | + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( |
| 103 | + ... sample_weight=True |
| 104 | + ... ) |
| 105 | + >>> lr = LogisticRegressionCV( |
| 106 | + ... cv=GroupKFold(), scoring=weighted_acc, |
| 107 | + ... ).set_fit_request(sample_weight=False) |
| 108 | + >>> cv_results = cross_validate( |
| 109 | + ... lr, |
| 110 | + ... X, |
| 111 | + ... y, |
| 112 | + ... cv=GroupKFold(), |
| 113 | + ... props={"sample_weight": my_weights, "groups": my_groups}, |
| 114 | + ... scoring=weighted_acc, |
| 115 | + ... ) |
| 116 | + |
| 117 | +If :meth:`linear_model.LogisticRegressionCV.set_fit_request` had not |
| 118 | +been called, :func:`~model_selection.cross_validate` would raise an |
| 119 | +error, because ``sample_weight`` is passed in but |
| 120 | +:class:`~linear_model.LogisticRegressionCV` has not been explicitly configured |
| 121 | +to either request or ignore the weights. |
| 122 | + |
| 123 | +Unweighted feature selection |
| 124 | +---------------------------- |
| 125 | + |
| 126 | +Setting request values for metadata is only required if the object, e.g. estimator, |
| 127 | +scorer, etc., is a consumer of that metadata. Unlike |
| 128 | +:class:`~linear_model.LogisticRegressionCV`, :class:`~feature_selection.SelectKBest` |
| 129 | +doesn't consume weights, so no request value for ``sample_weight`` is set on its |
| 130 | +instance and ``sample_weight`` is not routed to it:: |
| 131 | + |
| 132 | + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( |
| 133 | + ... sample_weight=True |
| 134 | + ... ) |
| 135 | + >>> lr = LogisticRegressionCV( |
| 136 | + ... cv=GroupKFold(), scoring=weighted_acc, |
| 137 | + ... ).set_fit_request(sample_weight=True) |
| 138 | + >>> sel = SelectKBest(k=2) |
| 139 | + >>> pipe = make_pipeline(sel, lr) |
| 140 | + >>> cv_results = cross_validate( |
| 141 | + ... pipe, |
| 142 | + ... X, |
| 143 | + ... y, |
| 144 | + ... cv=GroupKFold(), |
| 145 | + ... props={"sample_weight": my_weights, "groups": my_groups}, |
| 146 | + ... scoring=weighted_acc, |
| 147 | + ... ) |
| 148 | + |
| 149 | +Advanced: Different scoring and fitting weights |
| 150 | +----------------------------------------------- |
| 151 | + |
| 152 | +Despite :func:`~metrics.make_scorer` and |
| 153 | +:class:`~linear_model.LogisticRegressionCV` both expecting the key |
| 154 | +``sample_weight``, we can use aliases to pass different weights to different |
| 155 | +consumers. In this example, we pass ``scoring_weight`` to the scorer, and |
| 156 | +``fitting_weight`` to :class:`~linear_model.LogisticRegressionCV`:: |
| 157 | + |
| 158 | + >>> weighted_acc = make_scorer(accuracy_score).set_score_request( |
| 159 | + ... sample_weight="scoring_weight" |
| 160 | + ... ) |
| 161 | + >>> lr = LogisticRegressionCV( |
| 162 | + ... cv=GroupKFold(), scoring=weighted_acc, |
| 163 | + ... ).set_fit_request(sample_weight="fitting_weight") |
| 164 | + >>> cv_results = cross_validate( |
| 165 | + ... lr, |
| 166 | + ... X, |
| 167 | + ... y, |
| 168 | + ... cv=GroupKFold(), |
| 169 | + ... props={ |
| 170 | + ... "scoring_weight": my_weights, |
| 171 | + ... "fitting_weight": my_other_weights, |
| 172 | + ... "groups": my_groups, |
| 173 | + ... }, |
| 174 | + ... scoring=weighted_acc, |
| 175 | + ... ) |
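As a rough mental model of the aliasing above, a router can be thought of as renaming user-facing keys to each consumer's own parameter name. The following plain-Python sketch assumes a hypothetical ``resolve_aliases`` helper and dictionary layout; it does not reflect how scikit-learn stores or resolves requests internally.

```python
def resolve_aliases(requests, props):
    """Map user-facing metadata names to each consumer's parameter names.

    Hypothetical helper: ``requests`` maps a consumer label to a dict of
    {parameter name: alias}, and ``props`` holds the user-provided metadata.
    """
    return {
        consumer: {param: props[alias] for param, alias in aliases.items()}
        for consumer, aliases in requests.items()
    }


# Both consumers expect ``sample_weight`` but request it under different aliases.
requests = {
    "scorer": {"sample_weight": "scoring_weight"},
    "estimator": {"sample_weight": "fitting_weight"},
}
props = {"scoring_weight": [1.0, 2.0], "fitting_weight": [0.5, 0.5]}

routed = resolve_aliases(requests, props)
```

Each consumer still receives an argument named ``sample_weight``; only the key used by the caller differs.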
| 176 | + |
| 177 | +API Interface |
| 178 | +************* |
| 179 | + |
| 180 | +A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which |
| 181 | +accepts and uses some metadata in at least one of its methods (``fit``, |
| 182 | +``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``). |
| 183 | +Meta-estimators which only forward the metadata to other objects (the child |
| 184 | +estimator, scorers, or splitters) and don't use the metadata themselves are not |
| 185 | +consumers. (Meta-)Estimators which route metadata to other objects are |
| 186 | +*routers*. A (meta-)estimator can be a consumer and a router at the same time. |
| 187 | +(Meta-)Estimators and splitters expose a ``set_*_request`` method for each |
| 188 | +method which accepts at least one metadata. For instance, if an estimator |
| 189 | +supports ``sample_weight`` in ``fit`` and ``score``, it exposes |
| 190 | +``estimator.set_fit_request(sample_weight=value)`` and |
| 191 | +``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be: |
| 192 | + |
| 193 | +- ``True``: method requests a ``sample_weight``. This means that if the metadata |
| 194 | + is provided, it is used; if it is not provided, no error is raised. |
| 195 | +- ``False``: method does not request a ``sample_weight``. |
| 196 | +- ``None``: router will raise an error if ``sample_weight`` is passed. This is |
| 197 | + in almost all cases the default value when an object is instantiated and |
| 198 | + ensures the user sets the metadata requests explicitly when a metadata is |
| 199 | + passed. The only exceptions are ``Group*Fold`` splitters. |
| 200 | +- ``"param_name"``: if this estimator is used in a meta-estimator, the |
| 201 | + meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this |
| 202 | + estimator. This means the mapping between the metadata required by the |
| 203 | + object, e.g. ``sample_weight``, and what is provided by the user, e.g. |
| 204 | + ``my_weights``, is done at the router level, and not by the object, e.g. |
| 205 | + estimator, itself. |
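The four request values listed above can be summarised with a small plain-Python sketch of how a router might decide what to forward to one consumer. The function ``route_to_consumer`` and its error message are hypothetical and only mirror the rules in the list; this is not scikit-learn's actual routing code.

```python
def route_to_consumer(requests, props):
    """Return the metadata to forward, given one consumer's request values.

    Hypothetical helper: ``requests`` maps a parameter name to True, False,
    None, or an alias string; ``props`` is the user-provided metadata.
    """
    routed = {}
    for name, request in requests.items():
        if request is True:
            # Requested: forward it if provided, stay silent otherwise.
            if name in props:
                routed[name] = props[name]
        elif request is False:
            # Explicitly not requested: never forwarded.
            continue
        elif request is None:
            # Unset: passing the metadata is an error.
            if name in props:
                raise ValueError(
                    f"{name} is passed but is not explicitly set as requested or not"
                )
        else:
            # Alias: look the metadata up under the user-facing name.
            if request in props:
                routed[name] = props[request]
    return routed


props = {"sample_weight": [0.5, 2.0], "my_weights": [1.0, 3.0]}

routed_true = route_to_consumer({"sample_weight": True}, props)
routed_false = route_to_consumer({"sample_weight": False}, props)
routed_alias = route_to_consumer({"sample_weight": "my_weights"}, props)
routed_missing = route_to_consumer({"sample_weight": True}, {})

try:
    route_to_consumer({"sample_weight": None}, props)
except ValueError as exc:
    unset_error = str(exc)
```

In this sketch only the ``None`` case can fail, which matches the intent stated above: the user is forced to decide explicitly whenever a metadata is actually passed.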
| 206 | + |
| 207 | +Metadata are requested in the same way for scorers using ``set_score_request``. |
| 208 | + |
| 209 | +If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata |
| 210 | +request for all objects which potentially can consume ``sample_weight`` should |
| 211 | +be set by the user, otherwise an error is raised by the router object. For |
| 212 | +example, the following code raises an error, since it hasn't been explicitly |
| 213 | +specified whether ``sample_weight`` should be passed to the estimator's scorer |
| 214 | +or not:: |
| 215 | + |
| 216 | + >>> param_grid = {"C": [0.1, 1]} |
| 217 | + >>> lr = LogisticRegression().set_fit_request(sample_weight=True) |
| 218 | + >>> try: |
| 219 | + ... GridSearchCV( |
| 220 | + ... estimator=lr, param_grid=param_grid |
| 221 | + ... ).fit(X, y, sample_weight=my_weights) |
| 222 | + ... except ValueError as e: |
| 223 | + ... print(e) |
| 224 | + [sample_weight] are passed but are not explicitly set as requested or not for |
| 225 | + LogisticRegression.score |
| 226 | + |
| 227 | +The issue can be fixed by explicitly setting the request value:: |
| 228 | + |
| 229 | + >>> lr = LogisticRegression().set_fit_request( |
| 230 | + ... sample_weight=True |
| 231 | + ... ).set_score_request(sample_weight=False) |