@@ -18,7 +18,9 @@ def test_compute_class_weight():
18
18
y = np .asarray ([2 , 2 , 2 , 3 , 3 , 4 ])
19
19
classes = np .unique (y )
20
20
cw = compute_class_weight ("auto" , classes , y )
21
- assert_almost_equal (cw .sum (), classes .shape )
21
+ class_counts = np .bincount (y )[2 :]
22
+ # total effect of samples is preserved
23
+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
22
24
assert_true (cw [0 ] < cw [1 ] < cw [2 ])
23
25
24
26
@@ -63,19 +65,21 @@ def test_compute_class_weight_auto_negative():
63
65
# Test with unbalanced class labels.
64
66
y = np .asarray ([- 1 , 0 , 0 , - 2 , - 2 , - 2 ])
65
67
cw = compute_class_weight ("auto" , classes , y )
66
- assert_almost_equal (cw .sum (), classes .shape )
68
+ class_counts = np .bincount (y + 2 )
69
+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
67
70
assert_equal (len (cw ), len (classes ))
68
- assert_array_almost_equal (cw , np . array ([ 0.545 , 1.636 , 0.818 ]), decimal = 3 )
71
+ assert_array_almost_equal (cw , [ 2. / 3 , 2. , 1. ] )
69
72
70
73
71
74
def test_compute_class_weight_auto_unordered ():
72
75
"""Test compute_class_weight when classes are unordered"""
73
76
classes = np .array ([1 , 0 , 3 ])
74
77
y = np .asarray ([1 , 0 , 0 , 3 , 3 , 3 ])
75
78
cw = compute_class_weight ("auto" , classes , y )
76
- assert_almost_equal (cw .sum (), classes .shape )
79
+ class_counts = np .bincount (y )[classes ]
80
+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
77
81
assert_equal (len (cw ), len (classes ))
78
- assert_array_almost_equal (cw , np . array ([ 1.636 , 0.818 , 0.545 ]), decimal = 3 )
82
+ assert_array_almost_equal (cw , [ 2. , 1. , 2. / 3 ] )
79
83
80
84
81
85
def test_compute_sample_weight ():
@@ -97,8 +101,8 @@ def test_compute_sample_weight():
97
101
# Test with unbalanced classes
98
102
y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 , 3 ])
99
103
sample_weight = compute_sample_weight ("auto" , y )
100
- expected = np .asarray ([ .6 , .6 , .6 , .6 , .6 , .6 , 1.8 ])
101
- assert_array_almost_equal (sample_weight , expected )
104
+ expected = np .array ([ 0.7777 , 0.7777 , 0.7777 , 0.7777 , 0.7777 , 0.7777 , 2.3333 ])
105
+ assert_array_almost_equal (sample_weight , expected , decimal = 4 )
102
106
103
107
# Test with `None` weights
104
108
sample_weight = compute_sample_weight (None , y )
@@ -117,7 +121,7 @@ def test_compute_sample_weight():
117
121
# Test with multi-output of unbalanced classes
118
122
y = np .asarray ([[1 , 0 ], [1 , 0 ], [1 , 0 ], [2 , 1 ], [2 , 1 ], [2 , 1 ], [3 , - 1 ]])
119
123
sample_weight = compute_sample_weight ("auto" , y )
120
- assert_array_almost_equal (sample_weight , expected ** 2 )
124
+ assert_array_almost_equal (sample_weight , expected ** 2 , decimal = 3 )
121
125
122
126
123
127
def test_compute_sample_weight_with_subsample ():
@@ -135,12 +139,13 @@ def test_compute_sample_weight_with_subsample():
135
139
# Test with a subsample
136
140
y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ])
137
141
sample_weight = compute_sample_weight ("auto" , y , range (4 ))
138
- assert_array_almost_equal (sample_weight , [.5 , .5 , .5 , 1.5 , 1.5 , 1.5 ])
142
+ assert_array_almost_equal (sample_weight , [2. / 3 , 2. / 3 ,
143
+ 2. / 3 , 2. , 2. , 2. ])
139
144
140
145
# Test with a bootstrap subsample
141
146
y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ])
142
147
sample_weight = compute_sample_weight ("auto" , y , [0 , 1 , 1 , 2 , 2 , 3 ])
143
- expected = np .asarray ([1 / 3. , 1 / 3. , 1 / 3. , 5 / 3. , 5 / 3. , 5 / 3. ])
148
+ expected = np .asarray ([0.6 , 0.6 , 0.6 , 3. , 3. , 3. ])
144
149
assert_array_almost_equal (sample_weight , expected )
145
150
146
151
# Test with a bootstrap subsample for multi-output
0 commit comments