Commit 6e75058

qinhanmin2014 authored and TomDLT committed
[MRG+1] Improve the error message for some metrics when the shape of sample_weight is inappropriate (#9903)
1 parent 6b130f1 commit 6e75058
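
Before this change, a sample_weight whose length does not match y_true and y_pred only failed later, inside the metric computation, with an unhelpful message (or, for some metrics, not at all). With check_consistent_length called up front, the mismatch is reported explicitly. A minimal sketch of the new behaviour, using one of the patched metrics (array values are illustrative, not taken from the patch):

import numpy as np
from sklearn.metrics import accuracy_score

y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])
bad_weights = np.array([1.0, 2.0])  # only 2 weights for 4 samples

try:
    accuracy_score(y_true, y_pred, sample_weight=bad_weights)
except ValueError as exc:
    # After this commit the failure is explicit, e.g.:
    # "Found input variables with inconsistent numbers of samples: [4, 4, 2]"
    print(exc)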

File tree

3 files changed  (+23 / -6 lines)

  sklearn/metrics/classification.py
  sklearn/metrics/regression.py
  sklearn/metrics/tests/test_common.py


sklearn/metrics/classification.py

Lines changed: 8 additions & 2 deletions

@@ -174,6 +174,7 @@ def accuracy_score(y_true, y_pred, normalize=True, sample_weight=None):
 
     # Compute accuracy for each possible representation
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type.startswith('multilabel'):
         differing_labels = count_nonzero(y_true - y_pred, axis=1)
         score = differing_labels == 0
@@ -263,7 +264,7 @@ def confusion_matrix(y_true, y_pred, labels=None, sample_weight=None):
     else:
         sample_weight = np.asarray(sample_weight)
 
-    check_consistent_length(sample_weight, y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     n_labels = labels.size
     label_to_ind = dict((y, x) for x, y in enumerate(labels))
@@ -444,6 +445,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
 
     # Compute accuracy for each possible representation
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type.startswith('multilabel'):
         with np.errstate(divide='ignore', invalid='ignore'):
             # oddly, we may get an "invalid" rather than a "divide" error here
@@ -519,6 +521,7 @@ def matthews_corrcoef(y_true, y_pred, sample_weight=None):
     -0.33...
     """
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     if y_type not in {"binary", "multiclass"}:
         raise ValueError("%s is not supported" % y_type)
 
@@ -1023,6 +1026,7 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         raise ValueError("beta should be >0 in the F-beta score")
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
     present_labels = unique_labels(y_true, y_pred)
 
     if average == 'binary':
@@ -1550,6 +1554,7 @@ def hamming_loss(y_true, y_pred, labels=None, sample_weight=None,
         labels = classes
 
     y_type, y_true, y_pred = _check_targets(y_true, y_pred)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if labels is None:
         labels = unique_labels(y_true, y_pred)
@@ -1638,7 +1643,7 @@ def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None,
     The logarithm used is the natural logarithm (base-e).
     """
     y_pred = check_array(y_pred, ensure_2d=False)
-    check_consistent_length(y_pred, y_true)
+    check_consistent_length(y_pred, y_true, sample_weight)
 
     lb = LabelBinarizer()
 
@@ -1911,6 +1916,7 @@ def brier_score_loss(y_true, y_prob, sample_weight=None, pos_label=None):
     y_prob = column_or_1d(y_prob)
     assert_all_finite(y_true)
     assert_all_finite(y_prob)
+    check_consistent_length(y_true, y_prob, sample_weight)
 
     if pos_label is None:
         pos_label = y_true.max()
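
All of the classification metrics touched above now route y_true, y_pred and sample_weight through the same validator. A short sketch of how sklearn.utils.validation.check_consistent_length behaves here (None arguments are skipped, which is why the call is harmless when sample_weight is not passed; the inputs below are illustrative):

from sklearn.utils.validation import check_consistent_length

check_consistent_length([0, 1, 1], [0, 1, 0], None)             # OK: None is ignored
check_consistent_length([0, 1, 1], [0, 1, 0], [0.5, 0.5, 1.0])  # OK: all have 3 samples
try:
    check_consistent_length([0, 1, 1], [0, 1, 0], [0.5, 0.5])
except ValueError as exc:
    # "Found input variables with inconsistent numbers of samples: [3, 3, 2]"
    print(exc)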

sklearn/metrics/regression.py

Lines changed: 5 additions & 0 deletions

@@ -168,6 +168,7 @@ def mean_absolute_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
     output_errors = np.average(np.abs(y_pred - y_true),
                                weights=sample_weight, axis=0)
     if isinstance(multioutput, string_types):
@@ -236,6 +237,7 @@ def mean_squared_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
     output_errors = np.average((y_true - y_pred) ** 2, axis=0,
                                weights=sample_weight)
     if isinstance(multioutput, string_types):
@@ -306,6 +308,7 @@ def mean_squared_log_error(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if not (y_true >= 0).all() and not (y_pred >= 0).all():
         raise ValueError("Mean Squared Logarithmic Error cannot be used when "
@@ -409,6 +412,7 @@ def explained_variance_score(y_true, y_pred,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     y_diff_avg = np.average(y_true - y_pred, weights=sample_weight, axis=0)
     numerator = np.average((y_true - y_pred - y_diff_avg) ** 2,
@@ -528,6 +532,7 @@ def r2_score(y_true, y_pred, sample_weight=None,
     """
     y_type, y_true, y_pred, multioutput = _check_reg_targets(
         y_true, y_pred, multioutput)
+    check_consistent_length(y_true, y_pred, sample_weight)
 
     if sample_weight is not None:
         sample_weight = column_or_1d(sample_weight)
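
The regression metrics gain the same up-front check right after _check_reg_targets. A quick sketch with mean_squared_error, one of the patched functions (values illustrative):

import numpy as np
from sklearn.metrics import mean_squared_error

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

# Matching lengths: the metric is computed as before.
mean_squared_error(y_true, y_pred, sample_weight=np.array([1.0, 1.0, 2.0, 2.0]))

try:
    # 3 weights for 4 samples: rejected before any averaging happens.
    mean_squared_error(y_true, y_pred, sample_weight=np.array([1.0, 1.0, 2.0]))
except ValueError as exc:
    # "Found input variables with inconsistent numbers of samples: [4, 4, 3]"
    print(exc)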

sklearn/metrics/tests/test_common.py

Lines changed: 10 additions & 4 deletions

@@ -9,6 +9,7 @@
 from sklearn.datasets import make_multilabel_classification
 from sklearn.preprocessing import LabelBinarizer
 from sklearn.utils.multiclass import type_of_target
+from sklearn.utils.validation import _num_samples
 from sklearn.utils.validation import check_random_state
 from sklearn.utils import shuffle
 
@@ -1005,10 +1006,15 @@ def check_sample_weight_invariance(name, metric, y1, y2):
            err_msg="%s sample_weight is not invariant "
                    "under scaling" % name)
 
-    # Check that if sample_weight.shape[0] != y_true.shape[0], it raised an
-    # error
-    assert_raises(Exception, metric, y1, y2,
-                  sample_weight=np.hstack([sample_weight, sample_weight]))
+    # Check that if number of samples in y_true and sample_weight are not
+    # equal, meaningful error is raised.
+    error_message = ("Found input variables with inconsistent numbers of "
+                     "samples: [{}, {}, {}]".format(
+                         _num_samples(y1), _num_samples(y2),
+                         _num_samples(sample_weight) * 2))
+    assert_raise_message(ValueError, error_message, metric, y1, y2,
+                         sample_weight=np.hstack([sample_weight,
+                                                  sample_weight]))
 
 
 def test_sample_weight_invariance(n_samples=50):
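
The updated test no longer accepts just any exception: it builds the exact expected message with _num_samples, doubling the weight count because the weights are hstack-ed twice before being passed in. A hypothetical standalone version of the same assertion, assuming assert_raise_message from sklearn.utils.testing (the test module already has it available) and illustrative data:

import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.validation import _num_samples

y1 = np.array([0, 1, 1, 0, 1])
y2 = np.array([0, 1, 0, 0, 1])
sample_weight = np.ones(5)

# The third length is doubled because the weights are stacked twice below.
expected = ("Found input variables with inconsistent numbers of samples: "
            "[{}, {}, {}]".format(_num_samples(y1), _num_samples(y2),
                                  _num_samples(sample_weight) * 2))
assert_raise_message(ValueError, expected, accuracy_score, y1, y2,
                     sample_weight=np.hstack([sample_weight, sample_weight]))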

0 commit comments
