weighted metrics: fix sample_weight handling for average=samples · scikit-learn/scikit-learn@24c9340 · GitHub


Commit 24c9340

weighted metrics: fix sample_weight handling for average=samples
1 parent 1d5ba2a commit 24c9340

File tree

2 files changed: +58, -34 lines changed

sklearn/metrics/metrics.py

Lines changed: 29 additions & 13 deletions
@@ -472,34 +472,48 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
 
     y_true, y_score = check_arrays(y_true, y_score)
 
+    not_average_axis = 1
+    score_weight = sample_weight
+    average_weight = None
+
     if average == "micro":
+        if score_weight is not None:
+            score_weight = np.repeat(score_weight, y_true.shape[1])
         y_true = y_true.ravel()
         y_score = y_score.ravel()
 
-    if average == 'weighted':
-        weights = np.sum(y_true, axis=0)
-        if weights.sum() == 0:
+    elif average == 'weighted':
+        if score_weight is not None:
+            average_weight = np.sum(np.multiply(
+                y_true, np.reshape(score_weight, (-1, 1))), axis=0)
+        else:
+            average_weight = np.sum(y_true, axis=0)
+        if average_weight.sum() == 0:
             return 0
-    else:
-        weights = None
+
+    elif average == 'samples':
+        # swap average_weight <-> score_weight
+        average_weight = score_weight
+        score_weight = None
+        not_average_axis = 0
 
     if y_true.ndim == 1:
         y_true = y_true.reshape((-1, 1))
 
     if y_score.ndim == 1:
         y_score = y_score.reshape((-1, 1))
 
-    not_average_axis = 0 if average == 'samples' else 1
     n_classes = y_score.shape[not_average_axis]
     score = np.zeros((n_classes,))
     for c in range(n_classes):
         y_true_c = y_true.take([c], axis=not_average_axis).ravel()
         y_score_c = y_score.take([c], axis=not_average_axis).ravel()
-        score[c] = binary_metric(y_true_c, y_score_c)
+        score[c] = binary_metric(y_true_c, y_score_c,
+                                 sample_weight=score_weight)
 
     # Average the results
     if average is not None:
-        return np.average(score, weights=weights)
+        return np.average(score, weights=average_weight)
     else:
         return score
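
Note: after this hunk, average='samples' computes each per-sample score without weighting and applies sample_weight only in the final averaging step. A minimal sketch of that intended behaviour, using average_precision_score on made-up multilabel indicator data (illustrative only, not part of the commit):

import numpy as np
from sklearn.metrics import average_precision_score

y_true = np.array([[1, 0, 0, 1],
                   [0, 1, 1, 1],
                   [1, 1, 0, 1]])
y_score = np.array([[0.9, 0.2, 0.1, 0.8],
                    [0.1, 0.7, 0.6, 0.9],
                    [0.8, 0.9, 0.3, 0.7]])
sample_weight = np.array([1.0, 2.0, 3.0])

# Per-sample scores are computed without any weighting ...
per_sample = np.array([average_precision_score(t, s)
                       for t, s in zip(y_true, y_score)])
# ... and sample_weight only enters the final averaging step.
expected = np.average(per_sample, weights=sample_weight)

actual = average_precision_score(y_true, y_score, average='samples',
                                 sample_weight=sample_weight)
assert np.isclose(actual, expected)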

@@ -1687,20 +1701,20 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         y_pred = y_pred == 1
 
         if sample_weight is None:
-            sample_weight = 1
+            sum_weight = 1
             dtype = int
         else:
-            sample_weight = np.expand_dims(sample_weight, 1)
+            sum_weight = np.expand_dims(sample_weight, 1)
             dtype = float
 
         sum_axis = 1 if average == 'samples' else 0
         tp_sum = np.sum(
-            np.multiply(np.logical_and(y_true, y_pred), sample_weight),
+            np.multiply(np.logical_and(y_true, y_pred), sum_weight),
             axis=sum_axis)
         pred_sum = np.sum(
-            np.multiply(y_pred, sample_weight), axis=sum_axis, dtype=dtype)
+            np.multiply(y_pred, sum_weight), axis=sum_axis, dtype=dtype)
         true_sum = np.sum(
-            np.multiply(y_true, sample_weight), axis=sum_axis, dtype=dtype)
+            np.multiply(y_true, sum_weight), axis=sum_axis, dtype=dtype)
 
     elif average == 'samples':
         raise ValueError("Sample-based precision, recall, fscore is "
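
Note: the rename to sum_weight keeps the original sample_weight available for the later averaging step, while the weighted counts broadcast against the indicator matrices and are reduced along sum_axis. A small sketch with toy arrays (assumed values, not from the commit) of how that broadcasting behaves:

import numpy as np

y_true = np.array([[1, 0, 1],
                   [0, 1, 1]], dtype=bool)
y_pred = np.array([[1, 1, 0],
                   [0, 1, 1]], dtype=bool)
sample_weight = np.array([2.0, 3.0])

sum_weight = np.expand_dims(sample_weight, 1)   # shape (n_samples, 1)

# average != 'samples': reduce over the sample axis -> per-label counts
tp_per_label = np.sum(np.multiply(np.logical_and(y_true, y_pred), sum_weight),
                      axis=0)
# average == 'samples': reduce over the label axis -> per-sample counts
tp_per_sample = np.sum(np.multiply(np.logical_and(y_true, y_pred), sum_weight),
                       axis=1)

print(tp_per_label)    # [2. 3. 3.]
print(tp_per_sample)   # [2. 6.]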
@@ -1778,6 +1792,8 @@ def precision_recall_fscore_support(y_true, y_pred, beta=1.0, labels=None,
         weights = true_sum
         if weights.sum() == 0:
             return 0, 0, 0, None
+    elif average == 'samples':
+        weights = sample_weight
     else:
         weights = None
 
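
Note: with this hunk, average='samples' uses sample_weight as the weights of the final per-sample average. A rough check of that behaviour on made-up data (illustrative only, not part of the commit):

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([[1, 0, 1],
                   [0, 1, 1],
                   [1, 1, 0]])
y_pred = np.array([[1, 1, 0],
                   [0, 1, 1],
                   [1, 0, 1]])
w = np.array([1.0, 2.0, 3.0])

p, r, f, _ = precision_recall_fscore_support(y_true, y_pred,
                                             average='samples',
                                             sample_weight=w)

# Per-sample precision = tp / predicted positives for each row,
# then averaged across samples with sample_weight as the weights.
per_sample_precision = (np.logical_and(y_true, y_pred).sum(axis=1)
                        / y_pred.sum(axis=1))
assert np.isclose(p, np.average(per_sample_precision, weights=w))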

sklearn/metrics/tests/test_metrics.py

Lines changed: 29 additions & 21 deletions
@@ -352,6 +352,14 @@
     "samples_precision_score", "samples_recall_score",
 ]
 
+MULTILABEL_INDICATOR_METRICS_WITH_SAMPLE_WEIGHT = [
+    "average_precision_score",
+    "weighted_average_precision_score",
+    "micro_average_precision_score",
+    "macro_average_precision_score",
+    "samples_average_precision_score",
+]
+
 # Regression metrics that support multioutput and weighted samples
 MULTIOUTPUT_METRICS_WITH_SAMPLE_WEIGHT = [
     "mean_squared_error",
@@ -2565,7 +2573,7 @@ def test_averaging_multilabel_all_ones():
 @ignore_warnings
 def check_sample_weight_invariance(name, metric, y1, y2):
     rng = np.random.RandomState(0)
-    sample_weight = rng.randint(10, size=len(y1))
+    sample_weight = rng.randint(1, 10, size=len(y1))
 
     # check that unit weights gives the same score as no weight
     unweighted_score = metric(y1, y2, sample_weight=None)
@@ -2591,14 +2599,13 @@ def check_sample_weight_invariance(name, metric, y1, y2):
                        "not equal (%f != %f) for %s" % (
                            weighted_score, weighted_score_list, name))
 
-    if not name.startswith('samples'):
-        # check that integer weights is the same as repeated samples
-        repeat_weighted_score = metric(
-            np.repeat(y1, sample_weight, axis=0),
-            np.repeat(y2, sample_weight, axis=0), sample_weight=None)
-        assert_almost_equal(
-            weighted_score, repeat_weighted_score,
-            err_msg="Weighting %s is not equal to repeating samples" % name)
+    # check that integer weights is the same as repeated samples
+    repeat_weighted_score = metric(
+        np.repeat(y1, sample_weight, axis=0),
+        np.repeat(y2, sample_weight, axis=0), sample_weight=None)
+    assert_almost_equal(
+        weighted_score, repeat_weighted_score,
+        err_msg="Weighting %s is not equal to repeating samples" % name)
 
     if not name.startswith('unnormalized'):
         # check that the score is invariant under scaling of the weights by a
@@ -2612,33 +2619,34 @@ def check_sample_weight_invariance(name, metric, y1, y2):
 
 
 def test_sample_weight_invariance():
-    # generate some data
+    # binary
     y1, y2, _ = make_prediction(binary=True)
-
     for name in METRICS_WITH_SAMPLE_WEIGHT:
         metric = ALL_METRICS[name]
         yield check_sample_weight_invariance, name, metric, y1, y2
 
-    # multilabel
+    # multilabel sequence
     n_classes = 3
     n_samples = 10
-    _, y1_multilabel = make_multilabel_classification(
+    _, y1 = make_multilabel_classification(
         n_features=1, n_classes=n_classes,
         random_state=0, n_samples=n_samples)
-    _, y2_multilabel = make_multilabel_classification(
+    _, y2 = make_multilabel_classification(
         n_features=1, n_classes=n_classes,
         random_state=1, n_samples=n_samples)
-
     for name in MULTILABEL_METRICS_WITH_SAMPLE_WEIGHT:
         metric = ALL_METRICS[name]
-        yield (check_sample_weight_invariance,
-               name, metric, y1_multilabel, y2_multilabel)
+        yield (check_sample_weight_invariance, name, metric, y1, y2)
 
-    # multioutput
-    y1_multioutput = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]])
-    y2_multioutput = np.array([[0, 0, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1]])
+    # multilabel indicator
+    y1 = np.array([[1, 0, 0, 1], [0, 1, 1, 1], [1, 1, 0, 1]])
+    y2 = np.array([[0, 0, 1, 1], [1, 0, 1, 1], [1, 1, 0, 1]])
+    for name in MULTILABEL_INDICATOR_METRICS_WITH_SAMPLE_WEIGHT:
+        metric = ALL_METRICS[name]
+        yield (check_sample_weight_invariance, name, metric, y1, y2)
 
+    # multioutput
     for name in MULTIOUTPUT_METRICS_WITH_SAMPLE_WEIGHT:
         metric = ALL_METRICS[name]
         yield (check_sample_weight_invariance,
-               name, metric, y1_multioutput, y2_multioutput)
+               name, metric, y1, y2)
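
Note: the repeated-samples check now also covers 'samples'-averaged metrics. A standalone illustration of the invariance being tested, on made-up data (not from the test file): integer sample weights should behave like repeating the corresponding rows.

import numpy as np
from sklearn.metrics import average_precision_score

rng = np.random.RandomState(0)
y_true = np.array([[1, 0, 0, 1],
                   [0, 1, 1, 1],
                   [1, 1, 0, 1]])
y_score = rng.rand(3, 4)
weights = np.array([2, 1, 3])

# weighting each sample by an integer ...
weighted = average_precision_score(y_true, y_score, average='samples',
                                   sample_weight=weights)
# ... should match repeating that sample the same number of times
repeated = average_precision_score(np.repeat(y_true, weights, axis=0),
                                   np.repeat(y_score, weights, axis=0),
                                   average='samples')
assert np.isclose(weighted, repeated)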

0 commit comments
