8000 CLN Support _MultimetricScorer directly for internal methods (#28359) · scikit-learn/scikit-learn@cf1fb22 · GitHub
[go: up one dir, main page]

Skip to content

Commit cf1fb22

Browse files
authored
CLN Support _MultimetricScorer directly for internal methods (#28359)
1 parent b0edfdf commit cf1fb22

File tree

3 files changed

+24
-38
lines changed

3 files changed

+24
-38
lines changed

sklearn/model_selection/_search.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -781,18 +781,11 @@ def _select_best_index(refit, refit_metric, results):
781781
best_index = results[f"rank_test_{refit_metric}"].argmin()
782782
return best_index
783783

784-
def _get_scorers(self, convert_multimetric):
784+
def _get_scorers(self):
785785
"""Get the scorer(s) to be used.
786786
787787
This is used in ``fit`` and ``get_metadata_routing``.
788788
789-
Parameters
790-
----------
791-
convert_multimetric : bool
792-
Whether to convert a dict of scorers to a _MultimetricScorer. This
793-
is used in ``get_metadata_routing`` to include the routing info for
794-
multiple scorers.
795-
796789
Returns
797790
-------
798791
scorers, refit_metric
@@ -807,10 +800,9 @@ def _get_scorers(self, convert_multimetric):
807800
scorers = _check_multimetric_scoring(self.estimator, self.scoring)
808801
self._check_refit_for_multimetric(scorers)
809802
refit_metric = self.refit
810-
if convert_multimetric and isinstance(scorers, dict):
811-
scorers = _MultimetricScorer(
812-
scorers=scorers, raise_exc=(self.error_score == "raise")
813-
)
803+
scorers = _MultimetricScorer(
804+
scorers=scorers, raise_exc=(self.error_score == "raise")
805+
)
814806

815807
return scorers, refit_metric
816808

@@ -866,10 +858,7 @@ def fit(self, X, y=None, **params):
866858
Instance of fitted estimator.
867859
"""
868860
estimator = self.estimator
869-
# Here we keep a dict of scorers as is, and only convert to a
870-
# _MultimetricScorer at a later stage. Issue:
871-
# https://github.com/scikit-learn/scikit-learn/issues/27001
872-
scorers, refit_metric = self._get_scorers(convert_multimetric=False)
861+
scorers, refit_metric = self._get_scorers()
873862

874863
X, y = indexable(X, y)
875864
params = _check_method_params(X, params=params)
@@ -1015,7 +1004,10 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None):
10151004
self.feature_names_in_ = self.best_estimator_.feature_names_in_
10161005

10171006
# Store the only scorer not as a dict for single metric evaluation
1018-
self.scorer_ = scorers
1007+
if isinstance(scorers, _MultimetricScorer):
1008+
self.scorer_ = scorers._scorers
1009+
else:
1010+
self.scorer_ = scorers
10191011

10201012
self.cv_results_ = results
10211013
self.n_splits_ = n_splits
@@ -1147,7 +1139,7 @@ def get_metadata_routing(self):
11471139
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
11481140
)
11491141

1150-
scorer, _ = self._get_scorers(convert_multimetric=True)
1142+
scorer, _ = self._get_scorers()
11511143
router.add(
11521144
scorer=scorer,
11531145
method_mapping=MethodMapping()

sklearn/model_selection/_validation.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -358,18 +358,11 @@ def cross_validate(
358358
scorers = check_scoring(estimator, scoring)
359359
else:
360360
scorers = _check_multimetric_scoring(estimator, scoring)
361+
scorers = _MultimetricScorer(
362+
scorers=scorers, raise_exc=(error_score == "raise")
363+
)
361364

362365
if _routing_enabled():
363-
# `cross_validate` will create a `_MultiMetricScorer` if `scoring` is a
364-
# dict at a later stage. We need the same object for the purpose of
365-
# routing. However, creating it here and passing it around would create
366-
# a much larger diff since the dict is used in many places.
367-
if isinstance(scorers, dict):
368-
_scorer = _MultimetricScorer(
369-
scorers=scorers, raise_exc=(error_score == "raise")
370-
)
371-
else:
372-
_scorer = scorers
373366
# For estimators, a MetadataRouter is created in get_metadata_routing
374367
# methods. For these router methods, we create the router to use
375368
# `process_routing` on it.
@@ -386,7 +379,7 @@ def cross_validate(
386379
method_mapping=MethodMapping().add(caller="fit", callee="fit"),
387380
)
388381
.add(
389-
scorer=_scorer,
382+
scorer=scorers,
390383
method_mapping=MethodMapping().add(caller="fit", callee="score"),
391384
)
392385
)
@@ -901,8 +894,8 @@ def _fit_and_score(
901894
if error_score == "raise":
902895
raise
903896
elif isinstance(error_score, numbers.Number):
904-
if isinstance(scorer, dict):
905-
test_scores = {name: error_score for name in scorer}
897+
if isinstance(scorer, _MultimetricScorer):
898+
test_scores = {name: error_score for name in scorer._scorers}
906899
if return_train_score:
907900
train_scores = test_scores.copy()
908901
else:
@@ -966,13 +959,9 @@ def _fit_and_score(
966959
def _score(estimator, X_test, y_test, scorer, score_params, error_score="raise"):
967960
"""Compute the score(s) of an estimator on a given test set.
968961
969-
Will return a dict of floats if `scorer` is a dict, otherwise a single
962+
Will return a dict of floats if `scorer` is a _MultiMetricScorer, otherwise a single
970963
float is returned.
971964
"""
972-
if isinstance(scorer, dict):
973-
# will cache method calls if needed. scorer() returns a dict
974-
scorer = _MultimetricScorer(scorers=scorer, raise_exc=(error_score == "raise"))
975-
976965
score_params = {} if score_params is None else score_params
977966

978967
try:

sklearn/model_selection/tests/test_validation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
precision_score,
4444
r2_score,
4545
)
46+
from sklearn.metrics._scorer import _MultimetricScorer
4647
from sklearn.model_selection import (
4748
GridSearchCV,
4849
GroupKFold,
@@ -2323,7 +2324,9 @@ def three_params_scorer(i, j, k):
23232324
),
23242325
(
23252326
True,
2326-
{"sc1": three_params_scorer, "sc2": three_params_scorer},
2327+
_MultimetricScorer(
2328+
scorers={"sc1": three_params_scorer, "sc2": three_params_scorer}
2329+
),
23272330
3,
23282331
(1, 3),
23292332
(0, 1),
@@ -2332,7 +2335,9 @@ def three_params_scorer(i, j, k):
23322335
),
23332336
(
23342337
False,
2335-
{"sc1": three_params_scorer, "sc2": three_params_scorer},
2338+
_MultimetricScorer(
2339+
scorers={"sc1": three_params_scorer, "sc2": three_params_scorer}
2340+
),
23362341
10,
23372342
(1, 3),
23382343
(0, 1),

0 commit comments

Comments
 (0)
0