@@ -186,10 +186,11 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator):
186
186
187
187
188
188
def check_multimetric_scoring_single_metric_wrapper (* args , ** kwargs ):
189
- # This wraps the _check_multimetric_scoring to take in single metric
190
- # scoring parameter so we can run the tests that we will run for
191
- # check_scoring, for check_multimetric_scoring too for single-metric
192
- # usecases
189
+ # This wraps the _check_multimetric_scoring to take in
190
+ # single metric scoring parameter so we can run the tests
191
+ # that we will run for check_scoring, for check_multimetric_scoring
192
+ # too for single-metric usecases
193
+
193
194
scorers , is_multi = _check_multimetric_scoring (* args , ** kwargs )
194
195
# For all single metric use cases, it should register as not multimetric
195
196
assert_false (is_multi )
@@ -370,7 +371,21 @@ def test_thresholded_scorers():
370
371
X , y = make_blobs (random_state = 0 , centers = 3 )
371
372
X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
372
373
clf .fit (X_train , y_train )
373
- assert_raises (ValueError , get_scorer ('roc_auc' ), clf , X_test , y_test )
374
+ with pytest .raises (ValueError , match = "multiclass format is not supported" ):
375
+ get_scorer ('roc_auc' )(clf , X_test , y_test )
376
+
377
+ # test error is raised with a single class present in model
378
+ # (predict_proba shape is not suitable for binary auc)
379
+ X , y = make_blobs (random_state = 0 , centers = 2 )
380
+ X_train , X_test , y_train , y_test = train_test_split (X , y , random_state = 0 )
381
+ clf = DecisionTreeClassifier ()
382
+ clf .fit (X_train , np .zeros_like (y_train ))
383
+ with pytest .raises (ValueError , match = "need classifier with two classes" ):
384
+ get_scorer ('roc_auc' )(clf , X_test , y_test )
385
+
386
+ # for proba scorers
5303
387
+ with pytest .raises (ValueError , match = "need classifier with two classes" ):
388
+ get_scorer ('neg_log_loss' )(clf , X_test , y_test )
374
389
375
390
376
391
def test_thresholded_scorers_multilabel_indicator_data ():
0 commit comments