ENH Add sample weight support to matthews_corrcoef metric · scikit-learn/scikit-learn@c7d3e0e · GitHub
[go: up one dir, main page]

Skip to content

Commit c7d3e0e

Browse files
raghavrv authored and MechCoder committed
ENH Add sample weight support to matthews_corrcoef metric
TST Add NRT to test the new matthews_corrcoef impl with np.corrcoef
TST Remove matthews_corrcoef metric from METRICS_WITHOUT_SAMPLE_WEIGHT
TST/ENH Test some boundary cases for the mcc function with and without sample_weight
ENH/PERF Store the standardized arrays to avoid repeated computations
ENH Raise warning if the variance of vectors is 0
ENH/TST Must test matthews_corrcoef score for sample_weight invariance
1 parent 9358777 commit c7d3e0e

File tree

4 files changed

+112
-31
lines changed

4 files changed

+112
-31
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ Enhancements
7373
- :class:`multiclass.OneVsOneClassifier` and :class:`multiclass.OneVsRestClassifier`
7474
now support ``partial_fit``. By `Asish Panda`_ and `Philipp Dowling`_.
7575

76+
- Add ``sample_weight`` parameter to :func:`metrics.matthews_corrcoef`.
77+
By `Jatin Shah`_ and `Raghav R V`_.
78+
7679
Bug fixes
7780
.........
7881

@@ -3925,7 +3928,7 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
39253928

39263929
.. _Matteo Visconti di Oleggio Castello: http://www.mvdoc.me
39273930

3928-
.. _Raghav R V: https://github.com/ragv
3931+
.. _Raghav R V: https://github.com/rvraghav93
39293932

39303933
.. _Trevor Stephens: http://trevorstephens.com/
39313934

sklearn/metrics/classification.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def jaccard_similarity_score(y_true, y_pred, normalize=True,
398398
return _weighted_sum(score, sample_weight, normalize)
399399

400400

401-
def matthews_corrcoef(y_true, y_pred):
401+
def matthews_corrcoef(y_true, y_pred, sample_weight=None):
402402
"""Compute the Matthews correlation coefficient (MCC) for binary classes
403403
404404
The Matthews correlation coefficient is used in machine learning as a
@@ -423,6 +423,9 @@ def matthews_corrcoef(y_true, y_pred):
423423
y_pred : array, shape = [n_samples]
424424
Estimated targets as returned by a classifier.
425425
426+
sample_weight : array-like of shape = [n_samples], default None
427+
Sample weights.
428+
426429
Returns
427430
-------
428431
mcc : float
@@ -457,8 +460,17 @@ def matthews_corrcoef(y_true, y_pred):
457460
lb.fit(np.hstack([y_true, y_pred]))
458461
y_true = lb.transform(y_true)
459462
y_pred = lb.transform(y_pred)
460-
with np.errstate(invalid='ignore'):
461-
mcc = np.corrcoef(y_true, y_pred)[0, 1]
463+
mean_yt = np.average(y_true, weights=sample_weight)
464+
mean_yp = np.average(y_pred, weights=sample_weight)
465+
466+
y_true_u_cent = y_true - mean_yt
467+
y_pred_u_cent = y_pred - mean_yp
468+
469+
cov_ytyp = np.average(y_true_u_cent * y_pred_u_cent, weights=sample_weight)
470+
var_yt = np.average(y_true_u_cent ** 2, weights=sample_weight)
471+
var_yp = np.average(y_pred_u_cent ** 2, weights=sample_weight)
472+
473+
mcc = cov_ytyp / np.sqrt(var_yt * var_yp)
462474

463475
if np.isnan(mcc):
464476
return 0.

sklearn/metrics/tests/test_classification.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,58 @@ def test_matthews_corrcoef_nan():
331331
assert_equal(matthews_corrcoef([0, 0], [0, 1]), 0.0)
332332

333333

334+
def test_matthews_corrcoef_against_numpy_corrcoef():
335+
rng = np.random.RandomState(0)
336+
y_true = rng.randint(0, 2, size=20)
337+
y_pred = rng.randint(0, 2, size=20)
338+
339+
assert_almost_equal(matthews_corrcoef(y_true, y_pred),
340+
np.corrcoef(y_true, y_pred)[0, 1], 10)
341+
342+
343+
def test_matthews_corrcoef():
344+
rng = np.random.RandomState(0)
345+
y_true = ["a" if i == 0 else "b" for i in rng.randint(0, 2, size=20)]
346+
347+
# corrcoef of same vectors must be 1
348+
assert_almost_equal(matthews_corrcoef(y_true, y_true), 1.0)
349+
350+
# corrcoef, when the two vectors are opposites of each other, should be -1
351+
y_true_inv = ["b" if i == "a" else "a" for i in y_true]
352+
353+
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv), -1)
354+
y_true_inv2 = label_binarize(y_true, ["a", "b"]) * -1
355+
assert_almost_equal(matthews_corrcoef(y_true, y_true_inv2), -1)
356+
357+
# For the zero vector case, the corrcoef cannot be calculated and should
358+
# result in a RuntimeWarning
359+
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
360+
matthews_corrcoef, [0, 0, 0, 0], [0, 0, 0, 0])
361+
362+
# But will output 0
363+
assert_almost_equal(mcc, 0.)
364+
365+
# And also for any other vector with 0 variance
366+
mcc = assert_warns_message(RuntimeWarning, 'invalid value encountered',
367+
matthews_corrcoef, y_true,
368+
rng.randint(-100, 100) * np.ones(20, dtype=int))
369+
370+
# But will output 0
371+
assert_almost_equal(mcc, 0.)
372+
373+
# These two vectors have 0 correlation and hence mcc should be 0
374+
y_1 = [1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1]
375+
y_2 = [1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1]
376+
assert_almost_equal(matthews_corrcoef(y_1, y_2), 0.)
377+
378+
# Check that sample weight is able to selectively exclude
379+
mask = [1] * 10 + [0] * 10
380+
# Now the first half of the vector elements are alone given a weight of 1
381+
# and hence the mcc will not be a perfect 0 as in the previous case
382+
assert_raises(AssertionError, assert_almost_equal,
383+
matthews_corrcoef(y_1, y_2, sample_weight=mask), 0.)
384+
385+
334386
def test_precision_recall_f1_score_multiclass():
335387
# Test Precision Recall and F1 Score for multiclass classification task
336388
y_true, y_pred, _ = make_prediction(binary=False)

sklearn/metrics/tests/test_common.py

Lines changed: 41 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -187,26 +187,41 @@
187187
# When you add a new metric or functionality, check if a general test
188188
# is already written.
189189

190-
# Metric undefined with "binary" or "multiclass" input
191-
METRIC_UNDEFINED_MULTICLASS = [
192-
"samples_f0.5_score", "samples_f1_score", "samples_f2_score",
193-
"samples_precision_score", "samples_recall_score",
190+
# Those metrics don't support binary inputs
191+
METRIC_UNDEFINED_BINARY = [
192+
"samples_f0.5_score",
193+
"samples_f1_score",
194+
"samples_f2_score",
195+
"samples_precision_score",
196+
"samples_recall_score",
197+
"coverage_error",
194198

195-
# Those metrics don't support multiclass outputs
196-
"average_precision_score", "weighted_average_precision_score",
197-
"micro_average_precision_score", "macro_average_precision_score",
199+
"roc_auc_score",
200+
"micro_roc_auc",
201+
"weighted_roc_auc",
202+
"macro_roc_auc",
203+
"samples_roc_auc",
204+
205+
"average_precision_score",
206+
"weighted_average_precision_score",
207+
"micro_average_precision_score",
208+
"macro_average_precision_score",
198209
"samples_average_precision_score",
199210

211+
"label_ranking_loss",
200212
"label_ranking_average_precision_score",
213+
]
201214

202-
"roc_auc_score", "micro_roc_auc", "weighted_roc_auc",
203-
"macro_roc_auc", "samples_roc_auc",
204-
205-
"coverage_error",
215+
# Those metrics don't support multiclass inputs
216+
METRIC_UNDEFINED_MULTICLASS = [
206217
"brier_score_loss",
207-
"label_ranking_loss",
218+
"matthews_corrcoef_score",
208219
]
209220

221+
# Metric undefined with "binary" or "multiclass" input
222+
METRIC_UNDEFINED_BINARY_MULTICLASS = set(METRIC_UNDEFINED_BINARY).union(
223+
set(METRIC_UNDEFINED_MULTICLASS))
224+
210225
# Metrics with an "average" argument
211226
METRICS_WITH_AVERAGING = [
212227
"precision_score", "recall_score", "f1_score", "f2_score", "f0.5_score"
@@ -346,7 +361,6 @@
346361
METRICS_WITHOUT_SAMPLE_WEIGHT = [
347362
"cohen_kappa_score",
348363
"confusion_matrix",
349-
"matthews_corrcoef_score",
350364
"median_absolute_error",
351365
]
352366

@@ -359,10 +373,9 @@ def test_symmetry():
359373
y_pred = random_state.randint(0, 2, size=(20, ))
360374

361375
# We shouldn't forget any metrics
362-
assert_equal(set(SYMMETRIC_METRICS).union(NOT_SYMMETRIC_METRICS,
363-
THRESHOLDED_METRICS,
364-
METRIC_UNDEFINED_MULTICLASS),
365-
set(ALL_METRICS))
376+
assert_equal(set(SYMMETRIC_METRICS).union(
377+
NOT_SYMMETRIC_METRICS, THRESHOLDED_METRICS,
378+
METRIC_UNDEFINED_BINARY_MULTICLASS), set(ALL_METRICS))
366379

367380
assert_equal(
368381
set(SYMMETRIC_METRICS).intersection(set(NOT_SYMMETRIC_METRICS)),
@@ -390,7 +403,7 @@ def test_sample_order_invariance():
390403
y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)
391404

392405
for name, metric in ALL_METRICS.items():
393-
if name in METRIC_UNDEFINED_MULTICLASS:
406+
if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
394407
continue
395408

396409
assert_almost_equal(metric(y_true, y_pred),
@@ -457,7 +470,7 @@ def test_format_invariance_with_1d_vectors():
457470
y2_row = np.reshape(y2_1d, (1, -1))
458471

459472
for name, metric in ALL_METRICS.items():
460-
if name in METRIC_UNDEFINED_MULTICLASS:
473+
if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
461474
continue
462475

463476
measure = metric(y1, y2)
@@ -532,7 +545,7 @@ def test_invariance_string_vs_numbers_labels():
532545
labels_str = ["eggs", "spam"]
533546

534547
for name, metric in CLASSIFICATION_METRICS.items():
535-
if name in METRIC_UNDEFINED_MULTICLASS:
548+
if name in METRIC_UNDEFINED_BINARY_MULTICLASS:
536549
continue
537550

538551
measure_with_number = metric(y1, y2)
@@ -613,7 +626,8 @@ def check_single_sample_multioutput(name):
613626

614627
def test_single_sample():
615628
for name in ALL_METRICS:
616-
if name in METRIC_UNDEFINED_MULTICLASS or name in THRESHOLDED_METRICS:
629+
if (name in METRIC_UNDEFINED_BINARY_MULTICLASS or
630+
name in THRESHOLDED_METRICS):
617631
# Those metrics are not always defined with one sample
618632
# or in multiclass classification
619633
continue
@@ -915,9 +929,9 @@ def check_sample_weight_invariance(name, metric, y1, y2):
915929
sample_weight=sample_weight.tolist())
916930
assert_almost_equal(
917931
weighted_score, weighted_score_list,
918-
err_msg="Weighted scores for array and list sample_weight input are "
919-
"not equal (%f != %f) for %s" % (
920-
weighted_score, weighted_score_list, name))
932+
err_msg=("Weighted scores for array and list "
933+
"sample_weight input are not equal (%f != %f) for %s") % (
934+
weighted_score, weighted_score_list, name))
921935

922936
# check that integer weights is the same as repeated samples
923937
repeat_weighted_score = metric(
@@ -963,14 +977,14 @@ def check_sample_weight_invariance(name, metric, y1, y2):
963977
def test_sample_weight_invariance(n_samples=50):
964978
random_state = check_random_state(0)
965979

966-
# binary output
980+
# binary
967981
random_state = check_random_state(0)
968982
y_true = random_state.randint(0, 2, size=(n_samples, ))
969983
y_pred = random_state.randint(0, 2, size=(n_samples, ))
970984
y_score = random_state.random_sample(size=(n_samples,))
971985
for name in ALL_METRICS:
972986
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
973-
name in METRIC_UNDEFINED_MULTICLASS):
987+
name in METRIC_UNDEFINED_BINARY):
974988
continue
975989
metric = ALL_METRICS[name]
976990
if name in THRESHOLDED_METRICS:
@@ -985,7 +999,7 @@ def test_sample_weight_invariance(n_samples=50):
985999
y_score = random_state.random_sample(size=(n_samples, 5))
9861000
for name in ALL_METRICS:
9871001
if (name in METRICS_WITHOUT_SAMPLE_WEIGHT or
988-
name in METRIC_UNDEFINED_MULTICLASS):
1002+
name in METRIC_UNDEFINED_BINARY_MULTICLASS):
9891003
continue
9901004
metric = ALL_METRICS[name]
9911005
if name in THRESHOLDED_METRICS:

0 commit comments

Comments
 (0)
0