8000 FEAT SLEP006: metadata routing infrastructure (#24027) · REDVM/scikit-learn@db84357 · GitHub
[go: up one dir, main page]

Skip to content

Commit db84357

Browse files
adrinjalalilorentzenchrogriselBenjaminBossanthomasjpfan
authored andcommitted
FEAT SLEP006: metadata routing infrastructure (scikit-learn#24027)
Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com> Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent ea05254 commit db84357

28 files changed

+4396
-155
lines changed

doc/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ def pytest_runtest_setup(item):
144144
setup_preprocessing()
145145
elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
146146
setup_unsupervised_learning()
147+
elif fname.endswith("metadata_routing.rst"):
148+
# TODO: remove this once implemented
149+
# Skip metarouting because is it is not fully implemented yet
150+
raise SkipTest(
151+
"Skipping doctest for metadata_routing.rst because it "
152+
"is not fully implemented yet"
153+
)
147154

148155
rst_files_requiring_matplotlib = [
149156
"modules/partial_dependence.rst",

doc/metadata_routing.rst

Lines changed: 231 additions & 0 deletions
< B47E tr class="diff-line-row">
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
2+
.. _metadata_routing:
3+
4+
.. currentmodule:: sklearn
5+
6+
.. TODO: update doc/conftest.py once document is updated and examples run.
7+
8+
Metadata Routing
9+
================
10+
11+
.. note::
12+
The Metadata Routing API is experimental, and is not implemented yet for many
13+
estimators. It may change without the usual deprecation cycle. By default
14+
this feature is not enabled. You can enable this feature by setting the
15+
``enable_metadata_routing`` flag to ``True``:
16+
17+
>>> import sklearn
18+
>>> sklearn.set_config(enable_metadata_routing=True)
19+
20+
This guide demonstrates how metadata such as ``sample_weight`` can be routed
21+
and passed along to estimators, scorers, and CV splitters through
22+
meta-estimators such as :class:`~pipeline.Pipeline` and
23+
:class:`~model_selection.GridSearchCV`. In order to pass metadata to a method
24+
such as ``fit`` or ``score``, the object consuming the metadata, must *request*
25+
it. For estimators and splitters, this is done via ``set_*_request`` methods,
26+
e.g. ``set_fit_request(...)``, and for scorers this is done via the
27+
``set_score_request`` method. For grouped splitters such as
28+
:class:`~model_selection.GroupKFold`, a ``groups`` parameter is requested by
29+
default. This is best demonstrated by the following examples.
30+
31+
If you are developing a scikit-learn compatible estimator or meta-estimator,
32+
you can check our related developer guide:
33+
:ref:`sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py`.
34+
35+
.. note::
36+
Note that the methods and requirements introduced in this document are only
37+
relevant if you want to pass metadata (e.g. ``sample_weight``) to a method.
38+
If you're only passing ``X`` and ``y`` and no other parameter / metadata to
39+
methods such as ``fit``, ``transform``, etc, then you don't need to set
40+
anything.
41+
42+
Usage Examples
43+
**************
44+
Here we present a few examples to show different common use-cases. The examples
45+
in this section require the following imports and data::
46+
47+
>>> import numpy as np
48+
>>> from sklearn.metrics import make_scorer, accuracy_score
49+
>>> from sklearn.linear_model import LogisticRegressionCV, LogisticRegression
50+
>>> from sklearn.model_selection import cross_validate, GridSearchCV, GroupKFold
51+
>>> from sklearn.feature_selection import SelectKBest
52+
>>> from sklearn.pipeline import make_pipeline
53+
>>> n_samples, n_features = 100, 4
54+
>>> rng = np.random.RandomState(42)
55+
>>> X = rng.rand(n_samples, n_features)
56+
>>> y = rng.randint(0, 2, size=n_samples)
57+
>>> my_groups = rng.randint(0, 10, size=n_samples)
58+
>>> my_weights = rng.rand(n_samples)
59+
>>> my_other_weights = rng.rand(n_samples)
60+
61+
Weighted scoring and fitting
62+
----------------------------
63+
64+
Here :class:`~model_selection.GroupKFold` requests ``groups`` by default. However, we
65+
need to explicitly request weights for our scorer and the internal cross validation of
66+
:class:`~linear_model.LogisticRegressionCV`. Both of these *consumers* know how to use
67+
metadata called ``sample_weight``::
68+
69+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
70+
... sample_weight=True
71+
... )
72+
>>> lr = LogisticRegressionCV(
73+
... cv=GroupKFold(), scoring=weighted_acc,
74+
... ).set_fit_request(sample_weight=True)
75+
>>> cv_results = cross_validate(
76+
... lr,
77+
... X,
78+
... y,
79+
... props={"sample_weight": my_weights, "groups": my_groups},
80+
... cv=GroupKFold(),
81+
... scoring=weighted_acc,
82+
... )
83+
84+
Note that in this example, ``my_weights`` is passed to both the scorer and
85+
:class:`~linear_model.LogisticRegressionCV`.
86+
87+
Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed
88+
(note the typo), :func:`~model_selection.cross_validate` would raise an error,
89+
since ``sample_weigh`` was not requested by any of its underlying objects.
90+
91+
Weighted scoring and unweighted fitting
92+
---------------------------------------
93+
94+
When passing metadata such as ``sample_weight`` around, all scikit-learn
95+
estimators require weights to be either explicitly requested or not requested
96+
(i.e. ``True`` or ``False``) when used in another router such as a
97+
:class:`~pipeline.Pipeline` or a ``*GridSearchCV``. To perform an unweighted
98+
fit, we need to configure :class:`~linear_model.LogisticRegressionCV` to not
99+
request sample weights, so that :func:`~model_selection.cross_validate` does
100+
not pass the weights along::
101+
102+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
103+
... sample_weight=True
104+
... )
105+
>>> lr = LogisticRegressionCV(
106+
... cv=GroupKFold(), scoring=weighted_acc,
107+
... ).set_fit_request(sample_weight=False)
108+
>>> cv_results = cross_validate(
109+
... lr,
110+
... X,
111+
... y,
112+
... cv=GroupKFold(),
113+
... props={"sample_weight": my_weights, "groups": my_groups},
114+
... scoring=weighted_acc,
115+
... )
116+
117+
If :meth:`linear_model.LogisticRegressionCV.set_fit_request` has not
118+
been called, :func:`~model_selection.cross_validate` will raise an
119+
error because ``sample_weight`` is passed in but
120+
:class:`~linear_model.LogisticRegressionCV` would not be explicitly configured
121+
to recognize the weights.
122+
123+
Unweighted feature selection
124+
----------------------------
125+
126+
Setting request values for metadata are only required if the object, e.g. estimator,
127+
scorer, etc., is a consumer of that metadata Unlike
128+
:class:`~linear_model.LogisticRegressionCV`, :class:`~feature_selection.SelectKBest`
129+
doesn't consume weights and therefore no request value for ``sample_weight`` on its
130+
instance is set and ``sample_weight`` is not routed to it::
131+
132+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
133+
... sample_weight=True
134+
... )
135+
>>> lr = LogisticRegressionCV(
136+
... cv=GroupKFold(), scoring=weighted_acc,
137+
... ).set_fit_request(sample_weight=True)
138+
>>> sel = SelectKBest(k=2)
139+
>>> pipe = make_pipeline(sel, lr)
140+
>>> cv_results = cross_validate(
141+
... pipe,
142+
... X,
143+
... y,
144+
... cv=GroupKFold(),
145+
... props={"sample_weight": my_weights, "groups": my_groups},
146+
... scoring=weighted_acc,
147+
... )
148+
149+
Advanced: Different scoring and fitting weights
150+
-----------------------------------------------
151+
152+
Despite :func:`~metrics.make_scorer` and
153+
:class:`~linear_model.LogisticRegressionCV` both expecting the key
154+
``sample_weight``, we can use aliases to pass different weights to different
155+
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
156+
``fitting_weight`` to :class:`~linear_model.LogisticRegressionCV`::
157+
158+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
159+
... sample_weight="scoring_weight"
160+
... )
161+
>>> lr = LogisticRegressionCV(
162+
... cv=GroupKFold(), scoring=weighted_acc,
163+
... ).set_fit_request(sample_weight="fitting_weight")
164+
>>> cv_results = cross_validate(
165+
... lr,
166+
... X,
167+
... y,
168+
... cv=GroupKFold(),
169+
... props={
170+
... "scoring_weight": my_weights,
171+
... "fitting_weight": my_other_weights,
172+
... "groups": my_groups,
173+
... },
174+
... scoring=weighted_acc,
175+
... )
176+
177+
API Interface
178+
*************
179+
180+
A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which
181+
accepts and uses some metadata in at least one of its methods (``fit``,
182+
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
183+
Meta-estimators which only forward the metadata to other objects (the child
184+
estimator, scorers, or splitters) and don't use the metadata themselves are not
185+
consumers. (Meta-)Estimators which route metadata to other objects are
186+
*routers*. A(n) (meta-)estimator can be a consumer and a router at the same time.
187+
(Meta-)Estimators and splitters expose a ``set_*_request`` method for each
188+
method which accepts at least one metadata. For instance, if an estimator
189+
supports ``sample_weight`` in ``fit`` and ``score``, it exposes
190+
``estimator.set_fit_request(sample_weight=value)`` and
191+
``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be:
192+
193+
- ``True``: method requests a ``sample_weight``. This means if the metadata is
194+
provided, it will be used, otherwise no error is raised.
195+
- ``False``: method does not request a ``sample_weight``.
196+
- ``None``: router will raise an error if ``sample_weight`` is passed. This is
197+
in almost all cases the default value when an object is instantiated and
198+
ensures the user sets the metadata requests explicitly when a metadata is
199+
passed. The only exception are ``Group*Fold`` splitters.
200+
- ``"param_name"``: if this estimator is used in a meta-estimator, the
201+
meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
202+
estimator. This means the mapping between the metadata required by the
203+
object, e.g. ``sample_weight`` and what is provided by the user, e.g.
204+
``my_weights`` is done at the router level, and not by the object, e.g.
205+
estimator, itself.
206+
207+
Metadata are requested in the same way for scorers using ``set_score_request``.
208+
209+
If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata
210+
request for all objects which potentially can consume ``sample_weight`` should
211+
be set by the user, otherwise an error is raised by the router object. For
212+
example, the following code raises an error, since it hasn't been explicitly
213+
specified whether ``sample_weight`` should be passed to the estimator's scorer
214+
or not::
215+
216+
>>> param_grid = {"C": [0.1, 1]}
217+
>>> lr = LogisticRegression().set_fit_request(sample_weight=True)
218+
>>> try:
219+
... GridSearchCV(
220+
... estimator=lr, param_grid=param_grid
221+
... ).fit(X, y, sample_weight=my_weights)
222+
... except ValueError as e:
223+
... print(e)
224+
[sample_weight] are passed but are not explicitly set as requested or not for
225+
LogisticRegression.score
226+
227+
The issue can be fixed by explicitly setting the request value::
228+
229+
>>> lr = LogisticRegression().set_fit_request(
230+
... sample_weight=True
231+
... ).set_score_request(sample_weight=False)

doc/modules/classes.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Base classes
3434
base.DensityMixin
3535
base.RegressorMixin
3636
base.TransformerMixin
37+
base.MetaEstimatorMixin
3738
base.OneToOneFeatureMixin
3839
base.ClassNamePrefixFeaturesOutMixin
3940
feature_selection.SelectorMixin
@@ -1652,6 +1653,11 @@ Plotting
16521653
utils.validation.check_symmetric
16531654
utils.validation.column_or_1d
16541655
utils.validation.has_fit_parameter
1656+
utils.metadata_routing.get_routing_for_object
1657+
utils.metadata_routing.MetadataRouter
1658+
utils.metadata_routing.MetadataRequest
1659+
utils.metadata_routing.MethodMapping
1660+
utils.metadata_routing.process_routing
16551661

16561662
Specific utilities to list scikit-learn components:
16571663

doc/modules/model_evaluation.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,14 @@ the following two rules:
222222
Again, by convention higher numbers are better, so if your scorer
223223
returns loss, that value should be negated.
224224

225+
- Advanced: If it requires extra metadata to be passed to it, it should expose
226+
a ``get_metadata_routing`` method returning the requested metadata. The user
227+
should be able to set the requested metadata via a ``set_score_request``
228+
method. Please see :ref:`User Guide <metadata_routing>` and :ref:`Developer
229+
Guide <sphx_glr_auto_examples_miscellaneous_plot_metadata_routing.py>` for
230+
more details.
231+
232+
225233
.. note:: **Using custom scorers in functions where n_jobs > 1**
226234

227235
While defining the custom scoring function alongside the calling function

doc/user_guide.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,12 @@ User Guide
3131
model_persistence.rst
3232
common_pitfalls.rst
3333
dispatching.rst
34+
35+
Under Development
36+
-----------------
37+
38+
.. toctree::
39+
:numbered:
40+
:maxdepth: 1
41+
42+
metadata_routing.rst

doc/whats_new/v1.3.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,19 @@ Changes impacting all modules
145145
:pr:`26082` by :user:`Jérémie du Boisberranger <jeremiedbb>` and
146146
:user:`Olivier Grisel <ogrisel>`.
147147

148+
Experimental / Under Development
149+
--------------------------------
150+
151+
- |MajorFeature| :ref:`Metadata routing <metadata_routing>`'s related base
152+
methods are included in this release. This feature is only available via the
153+
`enable_metadata_routing` feature flag which can be enabled using
154+
:func:`sklearn.set_config` and :func:`sklearn.config_context`. For now this
155+
feature is mostly useful for third party developers to prepare their code
156+
base for metadata routing, and we strongly recommend that they also hide it
157+
behind the same feature flag, rather than having it enabled by default.
158+
:pr:`24027` by `Adrin Jalali`_, :user:`Benjamin Bossan <BenjaminBossan>`, and
159+
:user:`Omar Salman <OmarManzoor>`.
160+
148161
Changelog
149162
---------
150163

0 commit comments

Comments
 (0)
0