8000 Use more natural class_weight="auto" heuristic · scikit-learn/scikit-learn@3d7ad59 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3d7ad59

Browse files
committed
Use more natural class_weight="auto" heuristic
1 parent 8dbe3f8 commit 3d7ad59

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

sklearn/utils/class_weight.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def compute_class_weight(class_weight, classes, y):
4747
raise ValueError("classes should have valid labels that are in y")
4848

4949
# inversely proportional to the number of samples in the class
50-
recip_freq = 1. / bincount(y_ind)
51-
weight = recip_freq[le.transform(classes)] / np.mean(recip_freq)
50+
recip_freq = len(y) / (len(le.classes_) *
51+
bincount(y_ind).astype(np.float64))
52+
weight = recip_freq[le.transform(classes)]
5253
else:
5354
# user-defined dictionary
5455
weight = np.ones(classes.shape[0], dtype=np.float64, order='C')

sklearn/utils/estimator_checks.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -905,10 +905,9 @@ def check_class_weight_auto_linear_classifier(name, Classifier):
905905
coef_auto = classifier.fit(X, y).coef_.copy()
906906

907907
# Count each label occurrence to reweight manually
908-
mean_weight = (1. / 3 + 1. / 2) / 2
909908
class_weight = {
910-
1: 1. / 3 / mean_weight,
911-
-1: 1. / 2 / mean_weight,
909+
1: 5. / (2 * 3),
910+
-1: 5. / (2 * 2)
912911
}
913912
classifier.set_params(class_weight=class_weight)
914913
coef_manual = classifier.fit(X, y).coef_.copy()

sklearn/utils/tests/test_class_weight.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
import numpy as np
22

3+
from sklearn.linear_model import LogisticRegression
4+
from sklearn.datasets import make_blobs
5+
36
from sklearn.utils.class_weight import compute_class_weight
47
from sklearn.utils.class_weight import compute_sample_weight
58

@@ -26,6 +29,27 @@ def test_compute_class_weight_not_present():
2629
assert_raises(ValueError, compute_class_weight, "auto", classes, y)
2730

2831

32+
def test_compute_class_weight_invariance():
33+
# test that results with class_weight="auto" are invariant against
34+
# class imbalance if the number of samples is identical
35+
X, y = make_blobs(centers=2, random_state=0)
36+
# create dataset where class 1 is duplicated twice
37+
X_1 = np.vstack([X] + [X[y == 1]] * 2)
38+
y_1 = np.hstack([y] + [y[y == 1]] * 2)
39+
# create dataset where class 0 is duplicated twice
40+
X_0 = np.vstack([X] + [X[y == 0]] * 2)
41+
y_0 = np.hstack([y] + [y[y == 0]] * 2)
42+
# duplicate everything
43+
X_ = np.vstack([X] * 2)
44+
y_ = np.hstack([y] * 2)
45+
# results should be identical
46+
logreg1 = LogisticRegression(class_weight="auto").fit(X_1, y_1)
47+
logreg0 = LogisticRegression(class_weight="auto").fit(X_0, y_0)
48+
logreg = LogisticRegression(class_weight="auto").fit(X_, y_)
49+
assert_array_almost_equal(logreg1.coef_, logreg0.coef_)
50+
assert_array_almost_equal(logreg.coef_, logreg0.coef_)
51+
52+
2953
def test_compute_class_weight_auto_negative():
3054
"""Test compute_class_weight when labels are negative"""
3155
# Test with balanced class labels.
@@ -116,7 +140,7 @@ def test_compute_sample_weight_with_subsample():
116140
# Test with a bootstrap subsample
117141
y = np.asarray([1, 1, 1, 2, 2, 2])
118142
sample_weight = compute_sample_weight("auto", y, [0, 1, 1, 2, 2, 3])
119-
expected = np.asarray([1/3., 1/3., 1/3., 5/3., 5/3., 5/3.])
143+
expected = np.asarray([1 / 3., 1 / 3., 1 / 3., 5 / 3., 5 / 3., 5 / 3.])
120144
assert_array_almost_equal(sample_weight, expected)
121145

122146
# Test with a bootstrap subsample for multi-output

0 commit comments

Comments
 (0)
0