8000 [MRG+1] Raise error when SparseSeries is passed into classification m… · scikit-learn/scikit-learn@3700b96 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3700b96

Browse files
nielsenmarkus11jnothman
authored andcommitted
[MRG+1] Raise error when SparseSeries is passed into classification metrics (#7373)
* Raise error when SparseSeries is passed into roc_curve * Changed "y_true" in second if block to "y_score" * Remove code to import pandas and add sparseseries check to 'type_of_target' function. Finally, add 'type_of_target' call to _binary_clf_curve * Remove pandas import and old comparison in roc_curve. * Add test for 'type_of_target' function * Add white space after commas * Correct other white space issues * Move type_of_target test into try clause, remove test_precision_recall_curve_pos_label since as multiclass it doesn't make sense * Add test_precision_recall_curve_pos_label back in and also add test_binary_clf_curve to test new logic in _binary_clf_curve function * Correct syntax and formatting. * Remove trailing white space * Correct validation logic * Update test_multiclass.py per @jnothman 's request. * Import SkipTest function. * Remove extra white space from line 303
1 parent ef2408c commit 3700b96

File tree

4 files changed

+28
-1
lines changed

4 files changed

+28
-1
lines changed

sklearn/metrics/ranking.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,13 @@ def _binary_clf_curve(y_true, y_score, pos_label=None, sample_weight=None):
311311
thresholds : array, shape = [n_thresholds]
312312
Decreasing score values.
313313
"""
314-
check_consistent_length(y_true, y_score)
314+
# Check to make sure y_true is valid
315+
y_type = type_of_target(y_true)
316+
if not (y_type == "binary" or
317+
(y_type == "multiclass" and pos_label is not None)):
318+
raise ValueError("{0} format is not supported".format(y_type))
319+
320+
check_consistent_length(y_true, y_score, sample_weight)
315321
y_true = column_or_1d(y_true)
316322
y_score = column_or_1d(y_score)
317323
assert_all_finite(y_true)

sklearn/metrics/tests/test_ranking.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,14 @@ def test_auc_score_non_binary_class():
457457
roc_auc_score, y_true, y_pred)
458458

459459

460+
def test_binary_clf_curve():
461+
rng = check_random_state(404)
462+
y_true = rng.randint(0, 3, size=10)
463+
y_pred = rng.rand(10)
464+
msg = "multiclass format is not supported"
465+
assert_raise_message(ValueError, msg, precision_recall_curve,
466+
y_true, y_pred)
467+
460468
def test_precision_recall_curve():
461469
y_true, _, probas_pred = make_prediction(binary=True)
462470
_test_precision_recall_curve(y_true, probas_pred)

sklearn/utils/multiclass.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,10 @@ def type_of_target(y):
243243
raise ValueError('Expected array-like (array or non-string sequence), '
244244
'got %r' % y)
245245

246+
sparseseries = (y.__class__.__name__ == 'SparseSeries')
247+
if sparseseries:
248+
raise ValueError("y cannot be class 'SparseSeries'.")
249+
246250
if is_multilabel(y):
247251
return 'multilabel-indicator'
248252

sklearn/utils/tests/test_multiclass.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sklearn.utils.testing import assert_false
2222
from sklearn.utils.testing import assert_raises
2323
from sklearn.utils.testing import assert_raises_regex
24+
from sklearn.utils.testing import SkipTest
2425

2526
from sklearn.utils.multiclass import unique_labels
2627
from sklearn.utils.multiclass import is_multilabel
@@ -295,6 +296,14 @@ def test_type_of_target():
295296
' use a binary array or sparse matrix instead.')
296297
assert_raises_regex(ValueError, msg, type_of_target, example)
297298

299+
try:
300+
from pandas import SparseSeries
301+
except ImportError:
302+
raise SkipTest("Pandas not found")
303+
304+
y = SparseSeries([1, 0, 0, 1, 0])
305+
msg = "y cannot be class 'SparseSeries'."
306+
assert_raises_regex(ValueError, msg, type_of_target, y)
298307

299308
def test_class_distribution():
300309
y = np.array([[1, 0, 0, 1],

0 commit comments

Comments
 (0)
0