 from sklearn.utils import check_random_state
 from sklearn.utils._testing import (
     assert_array_equal,
-    assert_almost_equal,
     assert_allclose,
 )
 from sklearn.feature_selection._mutual_info import _compute_mi
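The hunks below thread a `global_dtype` fixture through the tests so that each one runs on both 64-bit and 32-bit floating-point inputs. The fixture itself lives in scikit-learn's shared test configuration; the following is a minimal sketch of the idea only (the real conftest additionally gates the float32 run behind an environment variable):

# conftest.py -- illustrative sketch, not scikit-learn's actual fixture
import numpy as np
import pytest


@pytest.fixture(params=[np.float64, np.float32])
def global_dtype(request):
    # Any test that takes `global_dtype` as an argument is collected
    # once per dtype in `params`.
    return request.param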
@@ -22,7 +21,7 @@ def test_compute_mi_dd():
     H_xy = -1 / 5 * np.log(1 / 5) - 2 / 5 * np.log(2 / 5) - 2 / 5 * np.log(2 / 5)
     I_xy = H_x + H_y - H_xy
 
-    assert_almost_equal(_compute_mi(x, y, True, True), I_xy)
+    assert_allclose(_compute_mi(x, y, x_discrete=True, y_discrete=True), I_xy)
 
 
 def test_compute_mi_cc(global_dtype):
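The hand computation in the hunk above uses the identity I(X; Y) = H(X) + H(Y) - H(X, Y). A self-contained way to verify that identity numerically on small discrete vectors (a sketch with arbitrary data, not necessarily the test's exact fixtures):

import numpy as np

x = np.array([0, 1, 1, 0, 0])
y = np.array([1, 0, 0, 0, 1])

def entropy(labels):
    # Empirical Shannon entropy in nats.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log(p))

# Joint entropy: count occurrences of each (x, y) pair.
_, joint_counts = np.unique(np.stack([x, y], axis=1), axis=0, return_counts=True)
p_xy = joint_counts / joint_counts.sum()
H_xy = -np.sum(p_xy * np.log(p_xy))

I_xy = entropy(x) + entropy(y) - H_xy  # mutual information in nats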
@@ -54,11 +53,13 @@ def test_compute_mi_cc(global_dtype):
     # Theory and computed values won't be very close.
     # We check here with a large relative tolerance.
     for n_neighbors in [3, 5, 7]:
-        I_computed = _compute_mi(x, y, False, False, n_neighbors)
+        I_computed = _compute_mi(
+            x, y, x_discrete=False, y_discrete=False, n_neighbors=n_neighbors
+        )
         assert_allclose(I_computed, I_theory, rtol=1e-1)
 
 
-def test_compute_mi_cd():
+def test_compute_mi_cd(global_dtype):
     # To test, define a joint distribution as follows:
     # p(x, y) = p(x) p(y | x)
     # X ~ Bernoulli(p)
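The continuous-continuous test above compares the k-NN estimate against the closed form for a bivariate Gaussian, I(X; Y) = -1/2 log(1 - rho^2), which is why only a loose rtol=1e-1 check is feasible: the Kraskov-style nearest-neighbor estimator is consistent but biased at finite sample sizes. A sketch of the same comparison through the public API (the correlation value and sample size here are illustrative):

import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.RandomState(0)
rho = 0.8
cov = np.array([[1.0, rho], [rho, 1.0]])
Z = rng.multivariate_normal(np.zeros(2), cov, size=10000)

I_theory = -0.5 * np.log(1 - rho**2)  # MI of a bivariate Gaussian, in nats
I_est = mutual_info_regression(Z[:, [0]], Z[:, 1], n_neighbors=3, random_state=0)[0]
# Only a loose agreement can be expected from the k-NN estimator.
assert np.isclose(I_est, I_theory, rtol=1e-1)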
@@ -80,7 +81,7 @@ def test_compute_mi_cd():
     for p in [0.3, 0.5, 0.7]:
         x = rng.uniform(size=n_samples) > p
 
-        y = np.empty(n_samples)
+        y = np.empty(n_samples, global_dtype)
         mask = x == 0
         y[mask] = rng.uniform(-1, 1, size=np.sum(mask))
         y[~mask] = rng.uniform(0, 2, size=np.sum(~mask))
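The theoretical value this test checks against can be derived directly from the distribution described in the comments. With Y | X=0 ~ U(-1, 1) and Y | X=1 ~ U(0, 2), the marginal density of Y is (1 - p)/2 on [-1, 0), 1/2 on [0, 1), and p/2 on [1, 2], while H(Y | X) = log 2 for either branch, so I(X; Y) = H(Y) - log 2. A sketch of that computation (derived here from the setup above, not copied from the test body):

import numpy as np

def mi_theory(p):
    # Differential entropy of the piecewise-constant marginal of Y;
    # each of the three pieces has length 1, so H(Y) = -sum(f * log(f)).
    densities = np.array([(1 - p) / 2, 1 / 2, p / 2])
    H_y = -np.sum(densities * np.log(densities))
    # Subtract H(Y | X) = log(2), the entropy of a width-2 uniform.
    return H_y - np.log(2)

for p in [0.3, 0.5, 0.7]:
    print(p, mi_theory(p))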
@@ -91,32 +92,36 @@ def test_compute_mi_cd():
 
         # Assert the same tolerance.
         for n_neighbors in [3, 5, 7]:
-            I_computed = _compute_mi(x, y, True, False, n_neighbors)
-            assert_almost_equal(I_computed, I_theory, 1)
+            I_computed = _compute_mi(
+                x, y, x_discrete=True, y_discrete=False, n_neighbors=n_neighbors
+            )
+            assert_allclose(I_computed, I_theory, rtol=1e-1)
 
 
-def test_compute_mi_cd_unique_label():
+def test_compute_mi_cd_unique_label(global_dtype):
     # Test that adding a unique label doesn't change MI.
     n_samples = 100
     x = np.random.uniform(size=n_samples) > 0.5
 
-    y = np.empty(n_samples)
+    y = np.empty(n_samples, global_dtype)
     mask = x == 0
     y[mask] = np.random.uniform(-1, 1, size=np.sum(mask))
     y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask))
 
-    mi_1 = _compute_mi(x, y, True, False)
+    mi_1 = _compute_mi(x, y, x_discrete=True, y_discrete=False)
 
     x = np.hstack((x, 2))
     y = np.hstack((y, 10))
-    mi_2 = _compute_mi(x, y, True, False)
+    mi_2 = _compute_mi(x, y, x_discrete=True, y_discrete=False)
 
-    assert mi_1 == mi_2
+    assert_allclose(mi_1, mi_2)
 
 
 # We are going to test that feature ordering by MI matches our expectations.
-def test_mutual_info_classif_discrete():
-    X = np.array([[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]])
+def test_mutual_info_classif_discrete(global_dtype):
+    X = np.array(
+        [[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=global_dtype
+    )
     y = np.array([0, 1, 2, 2, 1])
 
     # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly
@@ -125,7 +130,7 @@ def test_mutual_info_classif_discrete():
     assert_array_equal(np.argsort(-mi), np.array([0, 2, 1]))
 
 
-def test_mutual_info_regression():
+def test_mutual_info_regression(global_dtype):
     # We generate samples from a multivariate normal distribution, using a
     # transformation of initially uncorrelated variables. The zeroth
     # variable after the transformation is selected as the target vector,
@@ -136,19 +141,22 @@ def test_mutual_info_regression():
     mean = np.zeros(4)
 
     rng = check_random_state(0)
-    Z = rng.multivariate_normal(mean, cov, size=1000)
+    Z = rng.multivariate_normal(mean, cov, size=1000).astype(global_dtype, copy=False)
     X = Z[:, 1:]
     y = Z[:, 0]
 
     mi = mutual_info_regression(X, y, random_state=0)
     assert_array_equal(np.argsort(-mi), np.array([1, 2, 0]))
+    # XXX: should mutual_info_regression be fixed to avoid
+    # up-casting float32 inputs to float64?
+    assert mi.dtype == np.float64
 
 
-def test_mutual_info_classif_mixed():
+def test_mutual_info_classif_mixed(global_dtype):
     # Here the target is discrete and there are two continuous features and
     # one discrete feature. The idea of this test is clear from the code.
     rng = check_random_state(0)
-    X = rng.rand(1000, 3)
+    X = rng.rand(1000, 3).astype(global_dtype, copy=False)
     X[:, 1] += X[:, 0]
     y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
     X[:, 2] = X[:, 2] > 0.5
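For the jointly Gaussian data in test_mutual_info_regression, the expected ranking of the features can itself be predicted in closed form, since for a Gaussian pair I(x_i; y) = -1/2 log(1 - corr(x_i, y)^2); the trailing dtype assertion, meanwhile, records that the estimator currently up-casts float32 input to float64 rather than preserving it. A sketch of the closed-form ranking (the covariance matrix below is illustrative, not the one the test constructs):

import numpy as np

# Illustrative covariance of (y, x1, x2, x3); positive definite by construction.
cov = np.array([
    [1.0, 0.1, 0.7, 0.5],
    [0.1, 1.0, 0.0, 0.0],
    [0.7, 0.0, 1.0, 0.0],
    [0.5, 0.0, 0.0, 1.0],
])

std = np.sqrt(np.diag(cov))
corr = cov / np.outer(std, std)
# MI between the target (row 0) and each feature, for jointly Gaussian pairs.
mi_theory = -0.5 * np.log(1 - corr[0, 1:] ** 2)
print(np.argsort(-mi_theory))  # expected informativeness ranking: [1 2 0]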
@@ -168,9 +176,11 @@ def test_mutual_info_classif_mixed():
     assert mi_nn[2] == mi[2]
 
 
-def test_mutual_info_options():
-    X = np.array([[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=float)
-    y = np.array([0, 1, 2, 2, 1], dtype=float)
+def test_mutual_info_options(global_dtype):
+    X = np.array(
+        [[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=global_dtype
+    )
+    y = np.array([0, 1, 2, 2, 1], dtype=global_dtype)
     X_csr = csr_matrix(X)
 
     for mutual_info in (mutual_info_regression, mutual_info_classif):
@@ -192,8 +202,8 @@ def test_mutual_info_options():
         mi_5 = mutual_info(X, y, discrete_features=[True, False, True], random_state=0)
         mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)
 
-        assert_array_equal(mi_1, mi_2)
-        assert_array_equal(mi_3, mi_4)
-        assert_array_equal(mi_5, mi_6)
+        assert_allclose(mi_1, mi_2)
+        assert_allclose(mi_3, mi_4)
+        assert_allclose(mi_5, mi_6)
 
-    assert not np.allclose(mi_1, mi_3)
+        assert not np.allclose(mi_1, mi_3)
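The equivalence assertions in the final hunk pin down the two accepted spellings of discrete_features: a boolean mask and an array of column indices should yield identical estimates. A short usage sketch (hypothetical data):

import numpy as np
from sklearn.feature_selection import mutual_info_classif

rng = np.random.RandomState(0)
X = rng.rand(200, 3)
X[:, 2] = (X[:, 2] > 0.5).astype(float)  # make the third column discrete
y = (X[:, 0] + X[:, 2] > 1.0).astype(int)

by_mask = mutual_info_classif(X, y, discrete_features=[False, False, True], random_state=0)
by_index = mutual_info_classif(X, y, discrete_features=[2], random_state=0)
assert np.allclose(by_mask, by_index)  # same columns flagged, same estimates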