8000 FEAT Base sample-prop implementation and docs (#22083) · scikit-learn/scikit-learn@0b298ed · GitHub
[go: up one dir, main page]

Skip to content

Commit 0b298ed

Browse files
adrinjalalilorentzenchrogrisel
authored
FEAT Base sample-prop implementation and docs (#22083)
* initial base implementation commit * fix test_props and the issue with attribute starting with __ * skip doctest in metadata_routing.rst for now * DOC explain why aliasing on sub-estimator of a consumer/router is useful * reduce diff * DOC add user guide link to method docstrings * DOC apply Thomas's suggestions to the rst file * CLN address a few comments in docs * ignore sentinel docstring check * handling backward compatibility and deprecation prototype * Update examples/plot_metadata_routing.py Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> * make __metadata_request__* format more intuitive and less redundant * metadata_request_factory always returns a copy * fix tests for the changed __metadata_request__* format * in example: foo->sample_weight, bar->groups * get_method_input->get_input * minor comments from Guillaume * fix estimator checks tests * Improved sample props developer API * fixes, updated doc, decorator * Add docstrings and some API cleanup * unify serialize/deserialize methods * Add more docstring to process_routing * fix MetadataRouter.get_params parameter mismatch * DOC add missing name to MethodMetadataRequest.deserialize docstring * DOC add MethodMapping.add docstring * DOC fix colons after versionadded * fix {method}_requests return type annotation * metadata_request_factory -> metadata_router_factory and docstring fixes * move 'me' out of the map in MetadataRouter * more docstring refinements * cleanup API addresses and create a utils.metadata_routing sub-folder * fix module import issue * more tests and a few bug fixes * Joel's comments * make process_routing a function * docstring fix * ^type -> $type * remove deserialize, return instance, and add type as an attribute * remove sentinels and use strings instead * make RequestType searchable and check for valid identifier * Route -> MethodPair * remove unnecessary sorted * clarification on usage of the process_routing func in the example * only print methods with non-empty requests * fix test_string_representations * remove source build cache from CircleCI (temporarily) * Trigger CI * Invalidate linux-arm64 ccache my changing the key * Trigger CI * method, used_in -> callee, caller * show RequestType instead of RequestType.value in _serialize() * more informative error messages * fix checking for conflicting keys * get_router_for_object -> get_routing_for_object * \{method\}_requests -> set_\{method\}_request * address metadata_routing.rst comments * some test enhancements * TypeError for extra arguments * add_request: prop -> param * original_names -> return_alias * add more tests for MetadataRouter and MethodMapping * more suggestions from Joel's review * fix return type * apply more suggestions from Joel's review * Christian\'s suggestions * more notes from Christian * test_get_routing_for_object returns empty requests on unknown objects * more notes from Christian * remove double line break * more notes from Christian * more notes from Christian * make type private * add more comments/docs * fix test * fix nits * add forgotten nit Co-authored-by: Christian Lorentzen <lorentzen.ch@gmail.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 98a9ae0 commit 0b298ed

13 files changed

+2920
-3
lines changed

.circleci/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,10 @@ jobs:
132132
- checkout
133133
- run: ./build_tools/circle/checkout_merge_commit.sh
134134
- restore_cache:
135-
key: linux-arm64-{{ .Branch }}
135+
key: linux-arm64-ccache-v1-{{ .Branch }}
136136
- run: ./build_tools/circle/build_test_arm.sh
137137
- save_cache:
138-
key: linux-arm64-{{ .Branch }}
138+
key: linux-arm64-ccache-v1-{{ .Branch }}
139139
paths:
140140
- ~/.cache/ccache
141141
- ~/.cache/pip

doc/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ def pytest_runtest_setup(item):
138138
setup_preprocessing()
139139
elif fname.endswith("statistical_inference/unsupervised_learning.rst"):
140140
setup_unsupervised_learning()
141+
elif fname.endswith("metadata_routing.rst"):
142+
# TODO: remove this once implemented
143+
# Skip metarouting because is it is not fully implemented yet
144+
raise SkipTest(
145+
"Skipping doctest for metadata_routing.rst because it "
146+
"is not fully implemented yet"
147+
)
141148

142149
rst_files_requiring_matplotlib = [
143150
"modules/partial_dependence.rst",

doc/metadata_routing.rst

Lines changed: 215 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,215 @@
1+
2+
.. _metadata_routing:
3+
4+
.. TODO: update doc/conftest.py once document is updated and examples run.
5+
6+
Metadata Routing
7+
================
8+
9+
This guide demonstrates how metadata such as ``sample_weight`` can be routed
10+
and passed along to estimators, scorers, and CV splitters through
11+
meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
12+
metadata to a method such as ``fit`` or ``score``, the object accepting the
13+
metadata, must *request* it. For estimators and splitters this is done via
14+
``set_*_request`` methods, e.g. ``set_fit_request(...)``, and for scorers this
15+
is done via ``set_score_request`` method. For grouped splitters such as
16+
``GroupKFold`` a ``groups`` parameter is requested by default. This is best
17+
demonstrated by the following examples.
18+
19+
If you are developing a scikit-learn compatible estimator or meta-estimator,
20+
you can check our related developer guide:
21+
:ref:`sphx_glr_auto_examples_plot_metadata_routing.py`.
22+
23+
Usage Examples
24+
**************
25+
Here we present a few examples to show different common use-cases. The examples
26+
in this section require the following imports and data::
27+
28+
>>> import numpy as np
29+
>>> from sklearn.metrics import make_scorer, accuracy_score
30+
>>> from sklearn.linear_model import LogisticRegressionCV
31+
>>> from sklearn.linear_model import LogisticRegression
32+
>>> from sklearn.model_selection import cross_validate
33+
>>> from sklearn.model_selection import GridSearchCV
34+
>>> from sklearn.model_selection import GroupKFold
35+
>>> from sklearn.feature_selection import SelectKBest
36+
>>> from sklearn.utils.metadata_requests import RequestType
37+
>>> from sklearn.pipeline import make_pipeline
38+
>>> n_samples, n_features = 100, 4
39+
>>> X = np.random.rand(n_samples, n_features)
40+
>>> y = np.random.randint(0, 2, size=n_samples)
41+
>>> my_groups = np.random.randint(0, 10, size=n_samples)
42+
>>> my_weights = np.random.rand(n_samples)
43+
>>> my_other_weights = np.random.rand(n_samples)
44+
45+
Weighted scoring and fitting
46+
----------------------------
47+
48+
Here ``GroupKFold`` requests ``groups`` by default. However, we need to
49+
explicitly request weights in ``make_scorer`` and for ``LogisticRegressionCV``.
50+
Both of these *consumers* know how to use metadata called ``"sample_weight"``::
51+
52+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
53+
... sample_weight=True
54+
... )
55+
>>> lr = LogisticRegressionCV(
56+
... cv=GroupKFold(), scoring=weighted_acc,
57+
... ).set_fit_request(sample_weight=True)
58+
>>> cv_results = cross_validate(
59+
... lr,
60+
... X,
61+
... y,
62+
... cv=GroupKFold(),
63+
... props={"sample_weight": my_weights, "groups": my_groups},
64+
... scoring=weighted_acc,
65+
... )
66+
67+
Note that in this example, ``my_weights`` is passed to both the scorer and
68+
:class:`~linear_model.LogisticRegressionCV`.
69+
70+
Error handling: if ``props={"sample_weigh": my_weights, ...}`` were passed
71+
(note the typo), ``cross_validate`` would raise an error, since
72+
``sample_weigh`` was not requested by any of its children.
73+
74+
Weighted scoring and unweighted fitting
75+
---------------------------------------
76+
77+
All scikit-learn estimators requires weights to be either explicitly requested
78+
or not requested (i.e. ``UNREQUESTED``) when used in another router such as a
79+
``Pipeline`` or a ``*GridSearchCV``. To perform a unweighted fit, we need to
80+
configure :class:`~linear_model.LogisticRegressionCV` to not request sample
81+
weights, so that :func:`~model_selection.cross_validate` does not pass the
82+
weights along::
83+
84+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
85+
... sample_weight=True
86+
... )
87+
>>> lr = LogisticRegressionCV(
88+
... cv=GroupKFold(), scoring=weighted_acc,
89+
... ).set_fit_request(sample_weight=RequestType.UNREQUESTED)
90+
>>> cv_results = cross_validate(
91+
... lr,
92+
... X,
93+
... y,
94+
... cv=GroupKFold(),
95+
... props={"sample_weight": my_weights, "groups": my_groups},
96+
... scoring=weighted_acc,
97+
... )
98+
99+
Note the usage of ``RequestType`` which in this case is equivalent to
100+
``False``; the type is explained further at the end of this document.
101+
102+
If :class:`~linear_model.LogisticRegressionCV` does not call
103+
``set_fit_request``, :func:`~model_selection.cross_validate` will raise an
104+
error because weights is passed in but
105+
:class:`~linear_model.LogisticRegressionCV` would not be explicitly configured
106+
to recognize the weights.
107+
108+
Unweighted feature selection
109+
----------------------------
110+
111+
Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
112+
therefore `"sample_weight"` is not routed to it::
113+
114+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
115+
... sample_weight=True
116+
... )
117+
>>> lr = LogisticRegressionCV(
118+
... cv=GroupKFold(), scoring=weighted_acc,
119+
... ).set_fit_request(sample_weight=True)
120+
>>> sel = SelectKBest(k=2)
121+
>>> pipe = make_pipeline(sel, lr)
122+
>>> cv_results = cross_validate(
123+
... pipe,
124+
... X,
125+
... y,
126+
... cv=GroupKFold(),
127+
... props={"sample_weight": my_weights, "groups": my_groups},
128+
... scoring=weighted_acc,
129+
... )
130+
131+
Advanced: Different scoring and fitting weights
132+
-----------------------------------------------
133+
134+
Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting the key
135+
``sample_weight``, we can use aliases to pass different weights to different
136+
consumers. In this example, we pass ``scoring_weight`` to the scorer, and
137+
``fitting_weight`` to ``LogisticRegressionCV``::
138+
139+
>>> weighted_acc = make_scorer(accuracy_score).set_score_request(
140+
... sample_weight="scoring_weight"
141+
... )
142+
>>> lr = LogisticRegressionCV(
143+
... cv=GroupKFold(), scoring=weighted_acc,
144+
... ).set_fit_request(sample_weight="fitting_weight")
145+
>>> cv_results = cross_validate(
146+
... lr,
147+
... X,
148+
... y,
149+
... cv=GroupKFold(),
150+
... props={
151+
... "scoring_weight": my_weights,
152+
... "fitting_weight": my_other_weights,
153+
... "groups": my_groups,
154+
... },
155+
... scoring=weighted_acc,
156+
... )
157+
158+
API Interface
159+
*************
160+
161+
A *consumer* is an object (estimator, meta-estimator, scorer, splitter) which
162+
accepts and uses some metadata in at least one of its methods (``fit``,
163+
``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
164+
Meta-estimators which only forward the metadata to other objects (the child
165+
estimator, scorers, or splitters) and don't use the metadata themselves are not
166+
consumers. (Meta)Estimators which route metadata to other objects are
167+
*routers*. An (meta)estimator can be a consumer and a router at the same time.
168+
(Meta)Estimators and splitters expose a ``set_*_request`` method for each
169+
method which accepts at least one metadata. For instance, if an estimator
170+
supports ``sample_weight`` in ``fit`` and ``score``, it exposes
171+
``estimator.set_fit_request(sample_weight=value)`` and
172+
``estimator.set_score_request(sample_weight=value)``. Here ``value`` can be:
173+
174+
- ``RequestType.REQUESTED`` or ``True``: method requests a ``sample_weight``.
175+
This means if the metadata is provided, it will be used, otherwise no error
176+
is raised.
177+
- ``RequestType.UNREQUESTED`` or ``False``: method does not request a
178+
``sample_weight``.
179+
- ``RequestType.ERROR_IF_PASSED`` or ``None``: router will raise an error if
180+
``sample_weight`` is passed. This is in almost all cases the default value
181+
when an object is instantiated and ensures the user sets the metadata
182+
requests explicitly when a metadata is passed. The only exception are
183+
``Group*Fold`` splitters.
184+
- ``"param_name"``: if this estimator is used in a meta-estimator, the
185+
meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
186+
estimator. This means the mapping between the metadata required by the
187+
object, e.g. ``sample_weight`` and what is provided by the user, e.g.
188+
``my_weights`` is done at the router level, and not by the object, e.g.
189+
estimator, itself.
190+
191+
For the scorers, this is done the same way, using ``set_score_request`` method.
192+
193+
If a metadata, e.g. ``sample_weight``, is passed by the user, the metadata
194+
request for all objects which potentially can accept ``sample_weight`` should
195+
be set by the user, otherwise an error is raised by the router object. For
196+
example, the following code raises an error, since it hasn't been explicitly
197+
specified whether ``sample_weight`` should be passed to the estimator's scorer
198+
or not::
199+
200+
>>> param_grid = {"C": [0.1, 1]}
201+
>>> lr = LogisticRegression().set_fit_request(sample_weight=True)
202+
>>> try:
203+
... GridSearchCV(
204+
... estimator=lr, param_grid=param_grid
205+
... ).fit(X, y, sample_weight=my_weights)
206+
... except ValueError as e:
207+
... print(e)
208+
sample_weight is passed but is not explicitly set as requested or not for
209+
LogisticRegression.score
210+
211+
The issue can be fixed by explicitly setting the request value::
212+
213+
>>> lr = LogisticRegression().set_fit_request(
214+
... sample_weight=True
215+
... ).set_score_request(sample_weight=False)

doc/modules/classes.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ Base classes
3434
base.DensityMixin
3535
base.RegressorMixin
3636
base.TransformerMixin
37+
base.MetaEstimatorMixin
3738
feature_selection.SelectorMixin
3839

3940
Functions
@@ -1640,6 +1641,12 @@ Plotting
16401641
utils.validation.column_or_1d
16411642
utils.validation.has_fit_parameter
16421643
utils.all_estimators
1644+
utils.metadata_routing.RequestType
1645+
utils.metadata_routing.get_routing_for_object
1646+
utils.metadata_routing.MetadataRouter
1647+
utils.metadata_routing.MetadataRequest
1648+
utils.metadata_routing.MethodMapping
1649+
utils.metadata_routing.process_routing
16431650

16441651
Utilities from joblib:
16451652

doc/user_guide.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ User Guide
2727
visualizations.rst
2828
data_transforms.rst
2929
datasets.rst
30+
metadata_routing.rst
3031
computing.rst
3132
model_persistence.rst
3233
common_pitfalls.rst

0 commit comments

Comments
 (0)
0