ENH Adds _MultimetricScorer for Optimized Scoring (#14593) · scikit-learn/scikit-learn@fbb2c7c · GitHub

Commit fbb2c7c

thomasjpfan authored and jnothman committed
ENH Adds _MultimetricScorer for Optimized Scoring (#14593)
1 parent 66b0f5f commit fbb2c7c

File tree

4 files changed: +278 -72 lines changed

doc/whats_new/v0.22.rst

Lines changed: 5 additions & 0 deletions

@@ -348,6 +348,11 @@ Changelog
 - |Enhancement| :class:`model_selection.RandomizedSearchCV` now accepts lists
   of parameter distributions. :pr:`14549` by `Andreas Müller`_.
 
+- |Efficiency| Improved performance of multimetric scoring in
+  :func:`model_selection.cross_validate`,
+  :class:`model_selection.GridSearchCV`, and
+  :class:`model_selection.RandomizedSearchCV`. :pr:`14593` by `Thomas Fan`_.
+
 - |Fix| Reimplemented :class:`model_selection.StratifiedKFold` to fix an issue
   where one test set could be `n_classes` larger than another. Test sets should
   now be near-equally sized. :pr:`14704` by `Joel Nothman`_.
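
For context, this is the user-facing path the changelog entry refers to. A minimal sketch follows; the dataset and estimator are illustrative, and `KNeighborsClassifier` is chosen because it has no `decision_function`, so both `roc_auc` and `neg_log_loss` fall back to `predict_proba`, which the new `_MultimetricScorer` computes once per fold instead of once per metric:

```python
# Illustrative only: multimetric cross-validation that benefits from the
# shared-prediction scorer. The estimator has no decision_function, so
# roc_auc and neg_log_loss both need predict_proba, now computed once per fold.
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(random_state=0)
clf = KNeighborsClassifier(n_neighbors=5)

results = cross_validate(clf, X, y, cv=5,
                         scoring=['roc_auc', 'neg_log_loss'])
print(results['test_roc_auc'], results['test_neg_log_loss'])
```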

sklearn/metrics/scorer.py

Lines changed: 124 additions & 11 deletions

@@ -18,8 +18,9 @@
 # Arnaud Joly <arnaud.v.joly@gmail.com>
 # License: Simplified BSD
 
-from abc import ABCMeta
 from collections.abc import Iterable
+from functools import partial
+from collections import Counter
 
 import numpy as np
 
@@ -44,7 +45,82 @@
 from ..base import is_regressor
 
 
-class _BaseScorer(metaclass=ABCMeta):
+def _cached_call(cache, estimator, method, *args, **kwargs):
+    """Call estimator with method and args and kwargs."""
+    if cache is None:
+        return getattr(estimator, method)(*args, **kwargs)
+
+    try:
+        return cache[method]
+    except KeyError:
+        result = getattr(estimator, method)(*args, **kwargs)
+        cache[method] = result
+        return result
+
+
+class _MultimetricScorer:
+    """Callable for multimetric scoring used to avoid repeated calls
+    to `predict_proba`, `predict`, and `decision_function`.
+
+    `_MultimetricScorer` will return a dictionary of scores corresponding to
+    the scorers in the dictionary. Note that `_MultimetricScorer` can be
+    created with a dictionary with one key (i.e. only one actual scorer).
+
+    Parameters
+    ----------
+    scorers : dict
+        Dictionary mapping names to callable scorers.
+    """
+    def __init__(self, **scorers):
+        self._scorers = scorers
+
+    def __call__(self, estimator, *args, **kwargs):
+        """Evaluate predicted target values."""
+        scores = {}
+        cache = {} if self._use_cache(estimator) else None
+        cached_call = partial(_cached_call, cache)
+
+        for name, scorer in self._scorers.items():
+            if isinstance(scorer, _BaseScorer):
+                score = scorer._score(cached_call, estimator,
+                                      *args, **kwargs)
+            else:
+                score = scorer(estimator, *args, **kwargs)
+            scores[name] = score
+        return scores
+
+    def _use_cache(self, estimator):
+        """Return True if using a cache is beneficial.
+
+        Caching may be beneficial when one of these conditions holds:
+          - `_ProbaScorer` will be called twice.
+          - `_PredictScorer` will be called twice.
+          - `_ThresholdScorer` will be called twice.
+          - `_ThresholdScorer` and `_PredictScorer` are called and
+            estimator is a regressor.
+          - `_ThresholdScorer` and `_ProbaScorer` are called and
+            estimator does not have a `decision_function` attribute.
+
+        """
+        if len(self._scorers) == 1:  # Only one scorer
+            return False
+
+        counter = Counter([type(v) for v in self._scorers.values()])
+
+        if any(counter[known_type] > 1 for known_type in
+               [_PredictScorer, _ProbaScorer, _ThresholdScorer]):
+            return True
+
+        if counter[_ThresholdScorer]:
+            if is_regressor(estimator) and counter[_PredictScorer]:
+                return True
+            elif (counter[_ProbaScorer] and
+                  not hasattr(estimator, "decision_function")):
+                return True
+        return False
+
+
+class _BaseScorer:
     def __init__(self, score_func, sign, kwargs):
         self._kwargs = kwargs
         self._score_func = score_func
@@ -58,17 +134,47 @@ def __repr__(self):
                    "" if self._sign > 0 else ", greater_is_better=False",
                    self._factory_args(), kwargs_string))
 
+    def __call__(self, estimator, X, y_true, sample_weight=None):
+        """Evaluate predicted target values for X relative to y_true.
+
+        Parameters
+        ----------
+        estimator : object
+            Trained estimator to use for scoring. Must have a predict_proba
+            method; the output of that is used to compute the score.
+
+        X : array-like or sparse matrix
+            Test data that will be fed to estimator.predict.
+
+        y_true : array-like
+            Gold standard target values for X.
+
+        sample_weight : array-like, optional (default=None)
+            Sample weights.
+
+        Returns
+        -------
+        score : float
+            Score function applied to prediction of estimator on X.
+        """
+        return self._score(partial(_cached_call, None), estimator, X, y_true,
+                           sample_weight=sample_weight)
+
     def _factory_args(self):
         """Return non-default make_scorer arguments for repr."""
         return ""
 
 
 class _PredictScorer(_BaseScorer):
-    def __call__(self, estimator, X, y_true, sample_weight=None):
+    def _score(self, method_caller, estimator, X, y_true, sample_weight=None):
         """Evaluate predicted target values for X relative to y_true.
 
         Parameters
         ----------
+        method_caller : callable
+            Returns predictions given an estimator, method name, and other
+            arguments, potentially caching results.
+
         estimator : object
             Trained estimator to use for scoring. Must have a predict_proba
             method; the output of that is used to compute the score.
@@ -87,8 +193,7 @@ def __call__(self, estimator, X, y_true, sample_weight=None):
         score : float
             Score function applied to prediction of estimator on X.
         """
-
-        y_pred = estimator.predict(X)
+        y_pred = method_caller(estimator, "predict", X)
         if sample_weight is not None:
             return self._sign * self._score_func(y_true, y_pred,
                                                  sample_weight=sample_weight,
@@ -99,11 +204,15 @@ def __call__(self, estimator, X, y_true, sample_weight=None):
 
 
 class _ProbaScorer(_BaseScorer):
-    def __call__(self, clf, X, y, sample_weight=None):
+    def _score(self, method_caller, clf, X, y, sample_weight=None):
         """Evaluate predicted probabilities for X relative to y_true.
 
         Parameters
         ----------
+        method_caller : callable
+            Returns predictions given an estimator, method name, and other
+            arguments, potentially caching results.
+
         clf : object
             Trained classifier to use for scoring. Must have a predict_proba
             method; the output of that is used to compute the score.
@@ -124,7 +233,7 @@ def __call__(self, clf, X, y, sample_weight=None):
             Score function applied to prediction of estimator on X.
         """
         y_type = type_of_target(y)
-        y_pred = clf.predict_proba(X)
+        y_pred = method_caller(clf, "predict_proba", X)
         if y_type == "binary":
             if y_pred.shape[1] == 2:
                 y_pred = y_pred[:, 1]
@@ -145,11 +254,15 @@ def _factory_args(self):
 
 
 class _ThresholdScorer(_BaseScorer):
-    def __call__(self, clf, X, y, sample_weight=None):
+    def _score(self, method_caller, clf, X, y, sample_weight=None):
         """Evaluate decision function output for X relative to y_true.
 
         Parameters
        ----------
+        method_caller : callable
+            Returns predictions given an estimator, method name, and other
+            arguments, potentially caching results.
+
         clf : object
             Trained classifier to use for scoring. Must have either a
             decision_function method or a predict_proba method; the output of
@@ -176,17 +289,17 @@ def __call__(self, clf, X, y, sample_weight=None):
             raise ValueError("{0} format is not supported".format(y_type))
 
         if is_regressor(clf):
-            y_pred = clf.predict(X)
+            y_pred = method_caller(clf, "predict", X)
         else:
             try:
-                y_pred = clf.decision_function(X)
+                y_pred = method_caller(clf, "decision_function", X)
 
                 # For multi-output multi-class estimator
                 if isinstance(y_pred, list):
                     y_pred = np.vstack([p for p in y_pred]).T
 
             except (NotImplementedError, AttributeError):
-                y_pred = clf.predict_proba(X)
+                y_pred = method_caller(clf, "predict_proba", X)
 
         if y_type == "binary":
             if y_pred.shape[1] == 2:
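
To make the caching mechanism above concrete, here is a self-contained sketch of the pattern: `_cached_call` is copied from the diff, while `DummyEstimator` and the driver lines are illustrative only, not part of the commit. The cache is keyed by method name alone, which is safe here because every scorer within one `_MultimetricScorer` call scores the same `X`:

```python
from functools import partial


def _cached_call(cache, estimator, method, *args, **kwargs):
    """Call estimator with method and args and kwargs (as in the diff above)."""
    if cache is None:
        return getattr(estimator, method)(*args, **kwargs)

    try:
        return cache[method]
    except KeyError:
        result = getattr(estimator, method)(*args, **kwargs)
        cache[method] = result
        return result


class DummyEstimator:
    """Illustrative estimator that counts how often predict runs."""
    def __init__(self):
        self.n_predict_calls = 0

    def predict(self, X):
        self.n_predict_calls += 1
        return [0 for _ in X]


est = DummyEstimator()
cache = {}                            # shared by all scorers in one call
call = partial(_cached_call, cache)   # the same binding _MultimetricScorer builds

call(est, "predict", [[1], [2]])      # computes and caches the prediction
call(est, "predict", [[1], [2]])      # served from the cache
assert est.n_predict_calls == 1
```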

sklearn/metrics/tests/test_score_objects.py

Lines changed: 113 additions & 1 deletion

@@ -3,11 +3,13 @@
 import shutil
 import os
 import numbers
+from unittest.mock import Mock
 
 import numpy as np
 import pytest
 import joblib
 
+from numpy.testing import assert_allclose
 from sklearn.utils.testing import assert_almost_equal
 from sklearn.utils.testing import assert_array_equal
 from sklearn.utils.testing import ignore_warnings
@@ -18,10 +20,11 @@
                             jaccard_score)
 from sklearn.metrics import cluster as cluster_module
 from sklearn.metrics.scorer import (check_scoring, _PredictScorer,
-                                    _passthrough_scorer)
+                                    _passthrough_scorer, _MultimetricScorer)
 from sklearn.metrics import accuracy_score
 from sklearn.metrics.scorer import _check_multimetric_scoring
 from sklearn.metrics import make_scorer, get_scorer, SCORERS
+from sklearn.neighbors import KNeighborsClassifier
 from sklearn.svm import LinearSVC
 from sklearn.pipeline import make_pipeline
 from sklearn.cluster import KMeans
@@ -546,3 +549,112 @@ def test_scoring_is_not_metric():
         check_scoring(Ridge(), r2_score)
     with pytest.raises(ValueError, match='make_scorer'):
         check_scoring(KMeans(), cluster_module.adjusted_rand_score)
+
+
+@pytest.mark.parametrize(
+    ("scorers,expected_predict_count,"
+     "expected_predict_proba_count,expected_decision_func_count"),
+    [({'a1': 'accuracy', 'a2': 'accuracy',
+       'll1': 'neg_log_loss', 'll2': 'neg_log_loss',
+       'ra1': 'roc_auc', 'ra2': 'roc_auc'}, 1, 1, 1),
+     (['roc_auc', 'accuracy'], 1, 0, 1),
+     (['neg_log_loss', 'accuracy'], 1, 1, 0)])
+def test_multimetric_scorer_calls_method_once(scorers, expected_predict_count,
+                                              expected_predict_proba_count,
+                                              expected_decision_func_count):
+    X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])
+
+    mock_est = Mock()
+    fit_func = Mock(return_value=mock_est)
+    predict_func = Mock(return_value=y)
+
+    pos_proba = np.random.rand(X.shape[0])
+    proba = np.c_[1 - pos_proba, pos_proba]
+    predict_proba_func = Mock(return_value=proba)
+    decision_function_func = Mock(return_value=pos_proba)
+
+    mock_est.fit = fit_func
+    mock_est.predict = predict_func
+    mock_est.predict_proba = predict_proba_func
+    mock_est.decision_function = decision_function_func
+
+    scorer_dict, _ = _check_multimetric_scoring(LogisticRegression(), scorers)
+    multi_scorer = _MultimetricScorer(**scorer_dict)
+    results = multi_scorer(mock_est, X, y)
+
+    assert set(scorers) == set(results)  # compare dict keys
+
+    assert predict_func.call_count == expected_predict_count
+    assert predict_proba_func.call_count == expected_predict_proba_count
+    assert decision_function_func.call_count == expected_decision_func_count
+
+
+def test_multimetric_scorer_calls_method_once_classifier_no_decision():
+    predict_proba_call_cnt = 0
+
+    class MockKNeighborsClassifier(KNeighborsClassifier):
+        def predict_proba(self, X):
+            nonlocal predict_proba_call_cnt
+            predict_proba_call_cnt += 1
+            return super().predict_proba(X)
+
+    X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])
+
+    # no decision function
+    clf = MockKNeighborsClassifier(n_neighbors=1)
+    clf.fit(X, y)
+
+    scorers = ['roc_auc', 'neg_log_loss']
+    scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
+    scorer = _MultimetricScorer(**scorer_dict)
+    scorer(clf, X, y)
+
+    assert predict_proba_call_cnt == 1
+
+
+def test_multimetric_scorer_calls_method_once_regressor_threshold():
+    predict_called_cnt = 0
+
+    class MockDecisionTreeRegressor(DecisionTreeRegressor):
+        def predict(self, X):
+            nonlocal predict_called_cnt
+            predict_called_cnt += 1
+            return super().predict(X)
+
+    X, y = np.array([[1], [1], [0], [0], [0]]), np.array([0, 1, 1, 1, 0])
+
+    # no decision function
+    clf = MockDecisionTreeRegressor()
+    clf.fit(X, y)
+
+    scorers = {'neg_mse': 'neg_mean_squared_error', 'r2': 'roc_auc'}
+    scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
+    scorer = _MultimetricScorer(**scorer_dict)
+    scorer(clf, X, y)
+
+    assert predict_called_cnt == 1
+
+
+def test_multimetric_scorer_sanity_check():
+    # scoring dictionary returned is the same as calling each scorer separately
+    scorers = {'a1': 'accuracy', 'a2': 'accuracy',
+               'll1': 'neg_log_loss', 'll2': 'neg_log_loss',
+               'ra1': 'roc_auc', 'ra2': 'roc_auc'}
+
+    X, y = make_classification(random_state=0)
+
+    clf = DecisionTreeClassifier()
+    clf.fit(X, y)
+
+    scorer_dict, _ = _check_multimetric_scoring(clf, scorers)
+    multi_scorer = _MultimetricScorer(**scorer_dict)
+
+    result = multi_scorer(clf, X, y)
+
+    separate_scores = {
+        name: get_scorer(name)(clf, X, y)
+        for name in ['accuracy', 'neg_log_loss', 'roc_auc']}
+
+    for key, value in result.items():
+        score_name = scorers[key]
+        assert_allclose(value, separate_scores[score_name])
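
The search estimators named in the changelog entry use the same machinery. As a final hedged sketch, a multimetric grid search where both metrics are `_PredictScorer`s, so each candidate/fold pair now calls `predict` once for the pair rather than once per metric; the parameter grid and estimator are illustrative:

```python
# Illustrative only: multimetric grid search sharing predict across metrics.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = make_classification(random_state=0)

search = GridSearchCV(
    LogisticRegression(),
    param_grid={'C': [0.1, 1.0, 10.0]},
    scoring={'acc': 'accuracy', 'bal_acc': 'balanced_accuracy'},
    refit='acc',   # a refit metric must be named for multimetric search
    cv=3,
)
search.fit(X, y)
print(search.best_params_, search.cv_results_['mean_test_bal_acc'])
```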
