8000 FIX `RecursionError` bug with metadata routing in metaestimators with… · scikit-learn/scikit-learn@78675d1 · GitHub
[go: up one dir, main page]

Skip to content

Commit 78675d1

Browse files
FIX RecursionError bug with metadata routing in metaestimators with scoring (#28712)
LGTM. Thanks @StefanieSenger , @adrinjalali
1 parent 153e796 commit 78675d1

File tree

5 files changed

+100
-31
lines changed

5 files changed

+100
-31
lines changed

doc/whats_new/v1.5.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,10 @@ more details.
131131
methods, which can happen if one tries to decorate them.
132132
:pr:`28651` by `Adrin Jalali`_.
133133

134+
- |FIX| Prevent a `RecursionError` when estimators with the default `scoring`
135+
param (`None`) route metadata. :pr:`28712` by :user:`Stefanie Senger
136+
<StefanieSenger>`.
137+
134138
Changelog
135139
---------
136140

sklearn/linear_model/_ridge.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,12 +2394,10 @@ class RidgeCV(MultiOutputMixin, RegressorMixin, _BaseRidgeCV):
23942394
(i.e. data is expected to be centered).
23952395
23962396
scoring : str, callable, default=None
2397-
A string (see model evaluation documentation) or
2398-
a scorer callable object / function with signature
2399-
``scorer(estimator, X, y)``.
2400-
If None, the negative mean squared error if cv is 'auto' or None
2401-
(i.e. when using leave-one-out cross-validation), and r2 score
2402-
otherwise.
2397+
A string (see :ref:`scoring_parameter`) or a scorer callable object /
2398+
function with signature ``scorer(estimator, X, y)``. If None, the
2399+
negative mean squared error if cv is 'auto' or None (i.e. when using
2400+
leave-one-out cross-validation), and r2 score otherwise.
24032401
24042402
cv : int, cross-validation generator or an iterable, default=None
24052403
Determines the cross-validation splitting strategy.
@@ -2570,9 +2568,8 @@ class RidgeClassifierCV(_RidgeClassifierMixin, _BaseRidgeCV):
25702568
(i.e. data is expected to be centered).
25712569
25722570
scoring : str, callable, default=None
2573-
A string (see model evaluation documentation) or
2574-
a scorer callable object / function with signature
2575-
``scorer(estimator, X, y)``.
2571+
A string (see :ref:`scoring_parameter`) or a scorer callable object /
2572+
function with signature ``scorer(estimator, X, y)``.
25762573
25772574
cv : int, cross-validation generator or an iterable, default=None
25782575
Determines the cross-validation splitting strategy.

sklearn/linear_model/tests/test_ridge.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,3 +2085,20 @@ def test_ridge_sample_weight_consistency(
20852085
assert_allclose(reg1.coef_, reg2.coef_)
20862086
if fit_intercept:
20872087
assert_allclose(reg1.intercept_, reg2.intercept_)
2088+
2089+
2090+
# Metadata Routing Tests
2091+
# ======================
2092+
2093+
2094+
@pytest.mark.usefixtures("enable_slep006")
2095+
@pytest.mark.parametrize("metaestimator", [RidgeCV, RidgeClassifierCV])
2096+
def test_metadata_routing_with_default_scoring(metaestimator):
2097+
"""Test that `RidgeCV` or `RidgeClassifierCV` with default `scoring`
2098+
argument (`None`), don't enter into `RecursionError` when metadata is routed.
2099+
"""
2100+
metaestimator().get_metadata_routing()
2101+
2102+
2103+
# End of Metadata Routing Tests
2104+
# =============================

sklearn/metrics/_scorer.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -411,14 +411,31 @@ def get_scorer(scoring):
411411
return scorer
412412

413413

414-
class _PassthroughScorer:
414+
class _PassthroughScorer(_MetadataRequester):
415+
# Passes scoring of estimator's `score` method back to estimator if scoring
416+
# is `None`.
417+
415418
def __init__(self, estimator):
416419
self._estimator = estimator
417420

421+
requests = MetadataRequest(owner=self.__class__.__name__)
422+
try:
423+
requests.score = copy.deepcopy(estimator._metadata_request.score)
424+
except AttributeError:
425+
try:
426+
requests.score = copy.deepcopy(estimator._get_default_requests().score)
427+
except AttributeError:
428+
pass
429+
430+
self._metadata_request = requests
431+
418432
def __call__(self, estimator, *args, **kwargs):
419433
"""Method that wraps estimator.score"""
420434
return estimator.score(*args, **kwargs)
421435

436+
def __repr__(self):
437+
return f"{self._estimator.__class__}.score"
438+
422439
def get_metadata_routing(self):
423440
"""Get requested data properties.
424441
@@ -433,13 +450,32 @@ def get_metadata_routing(self):
433450
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
434451
routing information.
435452
"""
436-
# This scorer doesn't do any validation or routing, it only exposes the
437-
# requests of the given estimator. This object behaves as a consumer
438-
# rather than a router. Ideally it only exposes the score requests to
439-
# the parent object; however, that requires computing the routing for
440-
# meta-estimators, which would be more time consuming than simply
441-
# returning the child object's requests.
442-
return get_routing_for_object(self._estimator)
453+
return get_routing_for_object(self._metadata_request)
454+
455+
def set_score_request(self, **kwargs):
456+
"""Set requested parameters by the scorer.
457+
458+
Please see :ref:`User Guide <metadata_routing>` on how the routing
459+
mechanism works.
460+
461+
.. versionadded:: 1.5
462+
463+
Parameters
464+
----------
465+
kwargs : dict
466+
Arguments should be of the form ``param_name=alias``, and `alias`
467+
can be one of ``{True, False, None, str}``.
468+
"""
469+
if not _routing_enabled():
470+
raise RuntimeError(
471+
"This method is only available when metadata routing is enabled."
472+
" You can enable it using"
473+
" sklearn.set_config(enable_metadata_routing=True)."
474+
)
475+
476+
for param, alias in kwargs.items():
477+
self._metadata_request.score.add_request(param=param, alias=alias)
478+
return self
443479

444480

445481
def _check_multimetric_scoring(estimator, scoring):

sklearn/metrics/tests/test_score_objects.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from sklearn.pipeline import make_pipeline
5454
from sklearn.svm import LinearSVC
5555
from sklearn.tests.metadata_routing_common import (
56-
assert_request_equal,
5756
assert_request_is_empty,
5857
)
5958
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
@@ -1276,24 +1275,40 @@ def test_metadata_kwarg_conflict():
12761275

12771276

12781277
@pytest.mark.usefixtures("enable_slep006")
1279-
def test_PassthroughScorer_metadata_request():
1280-
"""Test that _PassthroughScorer properly routes metadata.
1278+
def test_PassthroughScorer_set_score_request():
1279+
"""Test that _PassthroughScorer.set_score_request adds the correct metadata request
1280+
on itself and doesn't change its estimator's routing."""
1281+
est = LogisticRegression().set_score_request(sample_weight="estimator_weights")
1282+
# make a `_PassthroughScorer` with `check_scoring`:
1283+
scorer = check_scoring(est, None)
1284+
assert (
1285+
scorer.get_metadata_routing().score.requests["sample_weight"]
1286+
== "estimator_weights"
1287+
)
12811288

1282-
_PassthroughScorer should behave like a consumer, mirroring whatever is the
1283-
underlying score method.
1284-
"""
1285-
scorer = _PassthroughScorer(
1286-
estimator=LinearSVC()
1287-
.set_score_request(sample_weight="alias")
1288-
.set_fit_request(sample_weight=True)
1289+
scorer.set_score_request(sample_weight="scorer_weights")
1290+
assert (
1291+
scorer.get_metadata_routing().score.requests["sample_weight"]
1292+
== "scorer_weights"
12891293
)
1290-
# Test that _PassthroughScorer doesn't change estimator's routing.
1291-
assert_request_equal(
1292-
scorer.get_metadata_routing(),
1293-
{"fit": {"sample_weight": True}, "score": {"sample_weight": "alias"}},
1294+
1295+
# making sure changing the passthrough object doesn't affect the estimator.
1296+
assert (
1297+
est.get_metadata_routing().score.requests["sample_weight"]
1298+
== "estimator_weights"
12941299
)
12951300

12961301

1302+
def test_PassthroughScorer_set_score_request_raises_without_routing_enabled():
1303+
"""Test that _PassthroughScorer.set_score_request raises if metadata routing is
1304+
disabled."""
1305+
scorer = check_scoring(LogisticRegression(), None)
1306+
msg = "This method is only available when metadata routing is enabled."
1307+
1308+
with pytest.raises(RuntimeError, match=msg):
1309+
scorer.set_score_request(sample_weight="my_weights")
1310+
1311+
12971312
@pytest.mark.usefixtures("enable_slep006")
12981313
def test_multimetric_scoring_metadata_routing():
12991314
# Test that _MultimetricScorer properly routes metadata.

0 commit comments

Comments
 (0)
0