ENH Improved error message for bad predict_proba shape in ThresholdScorer · jnothman/scikit-learn@1f4451e · GitHub

Commit 1f4451e

Reshama Shaikh authored and jnothman committed
ENH Improved error message for bad predict_proba shape in ThresholdScorer (scikit-learn#12486)
Continues and resolves scikit-learn#12221, fixes scikit-learn#7598
1 parent 0ec901f commit 1f4451e

File tree

2 files changed: +35, -7 lines changed

sklearn/metrics/scorer.py

Lines changed: 15 additions & 2 deletions
@@ -126,7 +126,13 @@ def __call__(self, clf, X, y, sample_weight=None):
         y_type = type_of_target(y)
         y_pred = clf.predict_proba(X)
         if y_type == "binary":
-            y_pred = y_pred[:, 1]
+            if y_pred.shape[1] == 2:
+                y_pred = y_pred[:, 1]
+            else:
+                raise ValueError('got predict_proba of shape {},'
+                                 ' but need classifier with two'
+                                 ' classes for {} scoring'.format(
+                                     y_pred.shape, self._score_func.__name__))
         if sample_weight is not None:
             return self._sign * self._score_func(y, y_pred,
                                                  sample_weight=sample_weight,
@@ -183,7 +189,14 @@ def __call__(self, clf, X, y, sample_weight=None):
                 y_pred = clf.predict_proba(X)
 
                 if y_type == "binary":
-                    y_pred = y_pred[:, 1]
+                    if y_pred.shape[1] == 2:
+                        y_pred = y_pred[:, 1]
+                    else:
+                        raise ValueError('got predict_proba of shape {},'
+                                         ' but need classifier with two'
+                                         ' classes for {} scoring'.format(
+                                             y_pred.shape,
+                                             self._score_func.__name__))
                 elif isinstance(y_pred, list):
                     y_pred = np.vstack([p[:, -1] for p in y_pred]).T
 
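The guard above is added to both the probability scorer and the thresholded scorer: a classifier fitted on a single class returns predict_proba with one column, so there is no positive-class column to slice. A minimal sketch of how the new message surfaces, mirroring the test added in this commit (an illustration, not part of the diff; it assumes a scikit-learn build that includes this change):

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import get_scorer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(random_state=0, centers=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Train on a single class so predict_proba has shape (n_samples, 1).
clf = DecisionTreeClassifier()
clf.fit(X_train, np.zeros_like(y_train))
print(clf.predict_proba(X_test).shape)  # one probability column, not two

try:
    # 'roc_auc' is a thresholded scorer; the tree has no decision_function,
    # so the scorer falls back to predict_proba and hits the new shape check.
    get_scorer('roc_auc')(clf, X_test, y_test)
except ValueError as exc:
    print(exc)  # got predict_proba of shape (..., 1), but need classifier
                # with two classes for roc_auc_score scoring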

sklearn/metrics/tests/test_score_objects.py

Lines changed: 20 additions & 5 deletions
@@ -186,10 +186,11 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator):
 
 
 def check_multimetric_scoring_single_metric_wrapper(*args, **kwargs):
-    # This wraps the _check_multimetric_scoring to take in single metric
-    # scoring parameter so we can run the tests that we will run for
-    # check_scoring, for check_multimetric_scoring too for single-metric
-    # usecases
+    # This wraps the _check_multimetric_scoring to take in
+    # single metric scoring parameter so we can run the tests
+    # that we will run for check_scoring, for check_multimetric_scoring
+    # too for single-metric usecases
+
     scorers, is_multi = _check_multimetric_scoring(*args, **kwargs)
     # For all single metric use cases, it should register as not multimetric
     assert_false(is_multi)
@@ -370,7 +371,21 @@ def test_thresholded_scorers():
     X, y = make_blobs(random_state=0, centers=3)
     X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
     clf.fit(X_train, y_train)
-    assert_raises(ValueError, get_scorer('roc_auc'), clf, X_test, y_test)
+    with pytest.raises(ValueError, match="multiclass format is not supported"):
+        get_scorer('roc_auc')(clf, X_test, y_test)
+
+    # test error is raised with a single class present in model
+    # (predict_proba shape is not suitable for binary auc)
+    X, y = make_blobs(random_state=0, centers=2)
+    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
+    clf = DecisionTreeClassifier()
+    clf.fit(X_train, np.zeros_like(y_train))
+    with pytest.raises(ValueError, match="need classifier with two classes"):
+        get_scorer('roc_auc')(clf, X_test, y_test)
+
+    # for proba scorers
+    with pytest.raises(ValueError, match="need classifier with two classes"):
+        get_scorer('neg_log_loss')(clf, X_test, y_test)
 
 
 def test_thresholded_scorers_multilabel_indicator_data():
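The final "for proba scorers" assertion above exercises the same guard through the probability scorer. A standalone sketch of that path (again an illustration, not part of the diff), where 'neg_log_loss' calls predict_proba directly:

import numpy as np
from sklearn.datasets import make_blobs
from sklearn.metrics import get_scorer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_blobs(random_state=0, centers=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier().fit(X_train, np.zeros_like(y_train))

try:
    # The probability scorer uses predict_proba unconditionally; its single
    # (n_samples, 1) column triggers the same ValueError as for 'roc_auc'.
    get_scorer('neg_log_loss')(clf, X_test, y_test)
except ValueError as exc:
    print(exc)  # reports the offending shape and the log_loss score function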

0 commit comments
