10000 [MRG+1] Improve the error message for some metrics when the shape of … · jwjohnson314/scikit-learn@28bd2cf · GitHub
[go: up one dir, main page]

Skip to content

Commit 28bd2cf

Browse files
Jeremiah Johnson authored and qinhanmin2014 committed
[MRG+1] Improve the error message for some metrics when the shape of sample_weight is inappropriate (scikit-learn#9903)
1 parent 4d35478 commit 28bd2cf

File tree

3 files changed

+23
-6
lines changed

3 files changed

+23
-6
lines changed

sklearn/metrics/classification.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
174174

175175
# Compute accuracy for each possible representation
176176
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
177+
check_consistent_length(y_true, y_pred, sample_weight)
177178
if y_type.startswith('multilabel'):
178179
differing_labels = count_nonzero(y_true - y_pred, axis=1)
179180
score = differing_labels == 0
@@ -337,7 +338,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
337338
else:
338339
sample_weight = np.asarray(sample_weight)
339340

340-
check_consistent_length(sample_weight, y_true, y_pred)
341+
check_consistent_length(y_true, y_pred, sample_weight)
341342

342343
n_labels = labels.size
343344
label_to_ind = dict((y, x) for x, y in enumerate(labels))
@@ -518,6 +519,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
518519

519520
# Compute accuracy for each possible representation
520521
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
522+
check_consistent_length(y_true, y_pred, sample_weight)
521523
if y_type.startswith('multilabel'):
522524
with np.errstate(divide='ignore', invalid='ignore'):
523525
# oddly, we may get an "invalid" rather than a "divide" error here
@@ -593,6 +595,7 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
593595
-0.33...
594596
"""
595597
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
598+
check_consistent_length(y_true, y_pred, sample_weight)
596599
if y_type not in {"binary", "multiclass"}:
597600
raise ValueError("%s is not supported" % y_type)
598601

@@ -1097,6 +1100,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
10971100
raise ValueError("beta should be >0 in the F-beta score")
10981101

10991102
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
1103+
check_consistent_length(y_true, y_pred, sample_weight)
11001104
present_labels = unique_labels(y_true, y_pred)
11011105

11021106
if average == 'binary':
@@ -1624,6 +1628,7 @@ def hamming_loss(y_true, y_pred, labels=None, sample_weight=None,
16241628
labels = classes
16251629

16261630
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
1631+
check_consistent_length(y_true, y_pred, sample_weight)
16271632

16281633
if labels is None:
16291634
labels = unique_labels(y_true, y_pred)
@@ -1712,7 +1717,7 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None,
17121717
The logarithm used is the natural logarithm (base-e).
17131718
"""
17141719
y_pred = check_array(y_pred, ensure_2d=False)
1715-
check_consistent_length(y_pred, y_true)
1720+
check_consistent_length(y_pred, y_true, sample_weight)
17161721

17171722
lb = LabelBinarizer()
17181723

@@ -1985,6 +1990,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None):
19851990
y_prob = column_or_1d(y_prob)
19861991
assert_all_finite(y_true)
19871992
assert_all_finite(y_prob)
1993+
check_consistent_length(y_true, y_prob, sample_weight)
19881994

19891995
if pos_label is None:
19901996
pos_label = y_true.max()

sklearn/metrics/regression.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def mean_absolute_error(y_true, y_pred,
168168
"""
169169
y_type, y_true, y_pred, multioutput = _check_reg_targets(
170170
y_true, y_pred, multioutput)
171+
check_consistent_length(y_true, y_pred, sample_weight)
171172
output_errors = np.average(np.abs(y_pred - y_true),
172173
weights=sample_weight, axis=0)
173174
if isinstance(multioutput, string_types):
@@ -236,6 +237,7 @@ def mean_squared_error(y_true, y_pred,
236237
"""
237238
y_type, y_true, y_pred, multioutput = _check_reg_targets(
238239
y_true, y_pred, multioutput)
240+
check_consistent_length(y_true, y_pred, sample_weight)
239241
output_errors = np.average((y_true - y_pred) ** 2, axis=0,
240242
weights=sample_weight)
241243
if isinstance(multioutput, string_types):
@@ -306,6 +308,7 @@ def mean_squared_log_error(y_true, y_pred,
306308
"""
307309
y_type, y_true, y_pred, multioutput = _check_reg_targets(
308310
y_true, y_pred, multioutput)
311+
check_consistent_length(y_true, y_pred, sample_weight)
309312

310313
if not (y_true >= 0).all() and not (y_pred >= 0).all():
311314
raise ValueError("Mean Squared Logarithmic Error cannot be used when "
@@ -409,6 +412,7 @@ def explained_variance_score(y_true, y_pred,
409412
"""
410413
y_type, y_true, y_pred, multioutput = _check_reg_targets(
411414
y_true, y_pred, multioutput)
415+
check_consistent_length(y_true, y_pred, sample_weight)
412416

413417
y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
414418
numerator = np.average((y_true - y_pred - y_diff_avg) ** 2,
@@ -528,6 +532,7 @@ def r2_score(y_true, y_pred, sample_weight=None,
528532
"""
529533
y_type, y_true, y_pred, multioutput = _check_reg_targets(
530534
y_true, y_pred, multioutput)
535+
check_consistent_length(y_true, y_pred, sample_weight)
531536

532537
if sample_weight is not None:
533538
sample_weight = column_or_1d(sample_weight)

sklearn/metrics/tests/test_common.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from sklearn.datasets import make_multilabel_classification
1010
from sklearn.preprocessing import LabelBinarizer
1111
from sklearn.utils.multiclass import type_of_target
12+
from sklearn.utils.validation import _num_samples
1213
from sklearn.utils.validation import check_random_state
1314
from sklearn.utils import shuffle
1415

@@ -1005,10 +1006,15 @@ def check_sample_weight_invariance(name, metric, y1, y2):
10051006
err_msg="%s sample_weight is not invariant "
10061007
"under scaling" % name)
10071008

1008-
# Check that if sample_weight.shape[0] != y_true.shape[0], it raised an
1009-
# error
1010-
assert_raises(Exception, metric, y1, y2,
1011-
sample_weight=np.hstack([sample_weight, sample_weight]))
1009+
# Check that if number of samples in y_true and sample_weight are not
1010+
# equal, meaningful error is raised.
1011+
error_message = ("Found input variables with inconsistent numbers of "
1012+
"samples: [{}, {}, {}]".format(
1013+
_num_samples(y1), _num_samples(y2),
1014+
_num_samples(sample_weight) * 2))
1015+
assert_raise_message(ValueError, error_message, metric, y1, y2,
1016+
sample_weight=np.hstack([sample_weight,
1017+
sample_weight]))
10121018

10131019

10141020
def test_sample_weight_invariance(n_samples=50):

0 commit comments

Comments (0)