10000 adjust some tests · scikit-learn/scikit-learn@b2cdc55 · GitHub
[go: up one dir, main page]

Skip to content

Commit b2cdc55

Browse files
committed
adjust some tests
1 parent 74fde88 commit b2cdc55

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

sklearn/utils/tests/test_class_weight.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ def test_compute_class_weight():
1818
y = np.asarray([2, 2, 2, 3, 3, 4])
1919
classes = np.unique(y)
2020
cw = compute_class_weight("auto", classes, y)
21-
assert_almost_equal(cw.sum(), classes.shape)
21+
class_counts = np.bincount(y)[2:]
22+
# total effect of samples is preserved
23+
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
2224
assert_true(cw[0] < cw[1] < cw[2])
2325

2426

@@ -63,19 +65,21 @@ def test_compute_class_weight_auto_negative():
6365
# Test with unbalanced class labels.
6466
y = np.asarray([-1, 0, 0, -2, -2, -2])
6567
cw = compute_class_weight("auto", classes, y)
66-
assert_almost_equal(cw.sum(), classes.shape)
68+
class_counts = np.bincount(y + 2)
69+
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
6770
assert_equal(len(cw), len(classes))
68-
assert_array_almost_equal(cw, np.array([0.545, 1.636, 0.818]), decimal=3)
71+
assert_array_almost_equal(cw, [2. / 3, 2., 1.])
6972

7073

7174
def test_compute_class_weight_auto_unordered():
7275
"""Test compute_class_weight when classes are unordered"""
7376
classes = np.array([1, 0, 3])
7477
y = np.asarray([1, 0, 0, 3, 3, 3])
7578
cw = compute_class_weight("auto", classes, y)
76-
assert_almost_equal(cw.sum(), classes.shape)
79+
class_counts = np.bincount(y)[classes]
80+
assert_almost_equal(np.dot(cw, class_counts), y.shape[0])
7781
assert_equal(len(cw), len(classes))
78-
assert_array_almost_equal(cw, np.array([1.636, 0.818, 0.545]), decimal=3)
82+
assert_array_almost_equal(cw, [2., 1., 2. / 3])
7983

8084

8185
def test_compute_sample_weight():
@@ -97,8 +101,8 @@ def test_compute_sample_weight():
97101
# Test with unbalanced classes
98102
y = np.asarray([1, 1, 1, 2, 2, 2, 3])
99103
sample_weight = compute_sample_weight("auto", y)
100-
expected = np.asarray([.6, .6, .6, .6, .6, .6, 1.8])
101-
assert_array_almost_equal(sample_weight, expected)
104+
expected = np.array([0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 0.7777, 2.3333])
105+
assert_array_almost_equal(sample_weight, expected, decimal=4)
102106

103107
# Test with `None` weights
104108
sample_weight = compute_sample_weight(None, y)
@@ -117,7 +121,7 @@ def test_compute_sample_weight():
117121
# Test with multi-output of unbalanced classes
118122
y = np.asarray([[1, 0], [1, 0], [1, 0], [2, 1], [2, 1], [2, 1], [3, -1]])
119123
sample_weight = compute_sample_weight("auto", y)
120-
assert_array_almost_equal(sample_weight, expected ** 2)
124+
assert_array_almost_equal(sample_weight, expected ** 2, decimal=3)
121125

122126

123127
def test_compute_sample_weight_with_subsample():
@@ -135,12 +139,13 @@ def test_compute_sample_weight_with_subsample():
135139
# Test with a subsample
136140
y = np.asarray([1, 1, 1, 2, 2, 2])
137141
sample_weight = compute_sample_weight("auto", y, range(4))
138-
assert_array_almost_equal(sample_weight, [.5, .5, .5, 1.5, 1.5, 1.5])
142+
assert_array_almost_equal(sample_weight, [2. / 3, 2. / 3,
143+
2. / 3, 2., 2., 2.])
139144

140145
# Test with a bootstrap subsample
141146
y = np.asarray([1, 1, 1, 2, 2, 2])
142147
sample_weight = compute_sample_weight("auto", y, [0, 1, 1, 2, 2, 3])
143-
expected = np.asarray([1 / 3., 1 / 3., 1 / 3., 5 / 3., 5 / 3., 5 / 3.])
148+
expected = np.asarray([0.6, 0.6, 0.6, 3., 3., 3.])
144149
assert_array_almost_equal(sample_weight, expected)
145150

146151
# Test with a bootstrap subsample for multi-output

0 commit comments

Comments
 (0)
0