@@ -18,7 +18,9 @@ def test_compute_class_weight():
1818 y = np .asarray ([2 , 2 , 2 , 3 , 3 , 4 ])
1919 classes = np .unique (y )
2020 cw = compute_class_weight ("auto" , classes , y )
21- assert_almost_equal (cw .sum (), classes .shape )
21+ class_counts = np .bincount (y )[2 :]
22+ # total effect of samples is preserved
23+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
2224 assert_true (cw [0 ] < cw [1 ] < cw [2 ])
2325
2426
@@ -63,19 +65,21 @@ def test_compute_class_weight_auto_negative():
6365 # Test with unbalanced class labels.
6466 y = np .asarray ([- 1 , 0 , 0 , - 2 , - 2 , - 2 ])
6567 cw = compute_class_weight ("auto" , classes , y )
66- assert_almost_equal (cw .sum (), classes .shape )
68+ class_counts = np .bincount (y + 2 )
69+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
6770 assert_equal (len (cw ), len (classes ))
68- assert_array_almost_equal (cw , np . array ([ 0.545 , 1.636 , 0.818 ]), decimal = 3 )
71+ assert_array_almost_equal (cw , [ 2. / 3 , 2. , 1. ] )
6972
7073
7174def test_compute_class_weight_auto_unordered ():
7275 """Test compute_class_weight when classes are unordered"""
7376 classes = np .array ([1 , 0 , 3 ])
7477 y = np .asarray ([1 , 0 , 0 , 3 , 3 , 3 ])
7578 cw = compute_class_weight ("auto" , classes , y )
76- assert_almost_equal (cw .sum (), classes .shape )
79+ class_counts = np .bincount (y )[classes ]
80+ assert_almost_equal (np .dot (cw , class_counts ), y .shape [0 ])
7781 assert_equal (len (cw ), len (classes ))
78- assert_array_almost_equal (cw , np . array ([ 1.636 , 0.818 , 0.545 ]), decimal = 3 )
82+ assert_array_almost_equal (cw , [ 2. , 1. , 2. / 3 ] )
7983
8084
8185def test_compute_sample_weight ():
@@ -97,8 +101,8 @@ def test_compute_sample_weight():
97101 # Test with unbalanced classes
98102 y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 , 3 ])
99103 sample_weight = compute_sample_weight ("auto" , y )
100- expected = np .asarray ([ .6 , .6 , .6 , .6 , .6 , .6 , 1.8 ])
101- assert_array_almost_equal (sample_weight , expected )
104+ expected = np .array ([ 0.7777 , 0.7777 , 0.7777 , 0.7777 , 0.7777 , 0.7777 , 2.3333 ])
105+ assert_array_almost_equal (sample_weight , expected , decimal = 4 )
102106
103107 # Test with `None` weights
104108 sample_weight = compute_sample_weight (None , y )
@@ -117,7 +121,7 @@ def test_compute_sample_weight():
117121 # Test with multi-output of unbalanced classes
118122 y = np .asarray ([[1 , 0 ], [1 , 0 ], [1 , 0 ], [2 , 1 ], [2 , 1 ], [2 , 1 ], [3 , - 1 ]])
119123 sample_weight = compute_sample_weight ("auto" , y )
120- assert_array_almost_equal (sample_weight , expected ** 2 )
124+ assert_array_almost_equal (sample_weight , expected ** 2 , decimal = 3 )
121125
122126
123127def test_compute_sample_weight_with_subsample ():
@@ -135,12 +139,13 @@ def test_compute_sample_weight_with_subsample():
135139 # Test with a subsample
136140 y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ])
137141 sample_weight = compute_sample_weight ("auto" , y , range (4 ))
138- assert_array_almost_equal (sample_weight , [.5 , .5 , .5 , 1.5 , 1.5 , 1.5 ])
142+ assert_array_almost_equal (sample_weight , [2. / 3 , 2. / 3 ,
143+ 2. / 3 , 2. , 2. , 2. ])
139144
140145 # Test with a bootstrap subsample
141146 y = np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ])
142147 sample_weight = compute_sample_weight ("auto" , y , [0 , 1 , 1 , 2 , 2 , 3 ])
143- expected = np .asarray ([1 / 3. , 1 / 3. , 1 / 3. , 5 / 3. , 5 / 3. , 5 / 3. ])
148+ expected = np .asarray ([0.6 , 0.6 , 0.6 , 3. , 3. , 3. ])
144149 assert_array_almost_equal (sample_weight , expected )
145150
146151 # Test with a bootstrap subsample for multi-output
0 commit comments