TST use global_dtype in feature_selection/tests/test_mutual_info.py … · scikit-learn/scikit-learn@e7d718c · GitHub

Commit e7d718c

jjerphan, ogrisel, and jeremiedbb authored
TST use global_dtype in feature_selection/tests/test_mutual_info.py (#22677)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Jérémie du Boisberranger <jeremiedbb@users.noreply.github.com>
1 parent 48f363f commit e7d718c
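The point of the change: the mutual information tests previously hard-coded float64 inputs. Switching them to the `global_dtype` pytest fixture lets the same test bodies also run against float32, as part of the wider effort to exercise scikit-learn under both precisions. As a rough illustration of how such a fixture can be wired up (a minimal sketch, not scikit-learn's verbatim conftest; `SKLEARN_RUN_FLOAT32_TESTS` is the environment switch the project uses to opt in to the float32 leg):

```python
# conftest.py -- approximate sketch of a dtype-parametrizing fixture
import os

import numpy as np
import pytest

# Run float32 variants only when explicitly requested, since they
# double the number of collected tests.
_dtypes = [np.float64]
if os.environ.get("SKLEARN_RUN_FLOAT32_TESTS") == "1":
    _dtypes.append(np.float32)


@pytest.fixture(params=_dtypes, ids=lambda d: np.dtype(d).name)
def global_dtype(request):
    """Yield each enabled global dtype to the requesting test."""
    return request.param
```

Any test function that takes `global_dtype` as a parameter is then collected once per enabled dtype, which is exactly what the new test signatures in this diff opt into.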

File tree

1 file changed: +35 −25 lines changed


sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 35 additions & 25 deletions
```diff
@@ -5,7 +5,6 @@
 from sklearn.utils import check_random_state
 from sklearn.utils._testing import (
     assert_array_equal,
-    assert_almost_equal,
     assert_allclose,
 )
 from sklearn.feature_selection._mutual_info import _compute_mi
@@ -22,7 +21,7 @@ def test_compute_mi_dd():
     H_xy = -1 / 5 * np.log(1 / 5) - 2 / 5 * np.log(2 / 5) - 2 / 5 * np.log(2 / 5)
     I_xy = H_x + H_y - H_xy
 
-    assert_almost_equal(_compute_mi(x, y, True, True), I_xy)
+    assert_allclose(_compute_mi(x, y, x_discrete=True, y_discrete=True), I_xy)
 
 
 def test_compute_mi_cc(global_dtype):
```
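`test_compute_mi_dd` above validates `_compute_mi` against the identity I(X; Y) = H(X) + H(Y) − H(X, Y) for two discrete variables. For readers who want to see a reference value come out of that identity, here is a small self-contained recomputation on an illustrative 2×2 joint distribution (the probability table is made up for the example, not taken from the test):

```python
import numpy as np

# Joint distribution p(x, y) over a 2x2 grid (illustrative values).
p_xy = np.array([[0.3, 0.2],
                 [0.1, 0.4]])
p_x = p_xy.sum(axis=1)  # marginal of x
p_y = p_xy.sum(axis=0)  # marginal of y


def entropy(p):
    p = p[p > 0]
    return -np.sum(p * np.log(p))


# I(X; Y) = H(X) + H(Y) - H(X, Y), in nats.
I_xy = entropy(p_x) + entropy(p_y) - entropy(p_xy.ravel())
print(I_xy)  # ~0.086 nats for the table above
```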
```diff
@@ -54,11 +53,13 @@ def test_compute_mi_cc(global_dtype):
     # Theory and computed values won't be very close
     # We here check with a large relative tolerance
     for n_neighbors in [3, 5, 7]:
-        I_computed = _compute_mi(x, y, False, False, n_neighbors)
+        I_computed = _compute_mi(
+            x, y, x_discrete=False, y_discrete=False, n_neighbors=n_neighbors
+        )
         assert_allclose(I_computed, I_theory, rtol=1e-1)
 
 
-def test_compute_mi_cd():
+def test_compute_mi_cd(global_dtype):
     # To test define a joint distribution as follows:
     # p(x, y) = p(x) p(y | x)
     # X ~ Bernoulli(p)
```
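`test_compute_mi_cc` checks the k-nearest-neighbor MI estimate against the closed-form mutual information of a bivariate normal distribution: for two jointly Gaussian variables with correlation coefficient ρ, I(X; Y) = −½ log(1 − ρ²). A quick numerical sanity check of the estimator against that formula, using the public `mutual_info_regression` API rather than the private `_compute_mi` helper, might look like this (the sample size and ρ are arbitrary):

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.RandomState(0)
rho = 0.8
cov = np.array([[1.0, rho],
                [rho, 1.0]])
xy = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=5000)

# Closed-form MI of a bivariate normal, in nats.
I_theory = -0.5 * np.log(1.0 - rho**2)  # ~0.511

I_estimated = mutual_info_regression(
    xy[:, [0]], xy[:, 1], n_neighbors=3, random_state=0
)[0]
# The kNN estimate is biased for finite samples, hence the loose
# tolerance, mirroring the rtol=1e-1 used in the test.
assert np.isclose(I_estimated, I_theory, rtol=1e-1)
```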
```diff
@@ -80,7 +81,7 @@ def test_compute_mi_cd():
     for p in [0.3, 0.5, 0.7]:
         x = rng.uniform(size=n_samples) > p
 
-        y = np.empty(n_samples)
+        y = np.empty(n_samples, global_dtype)
         mask = x == 0
         y[mask] = rng.uniform(-1, 1, size=np.sum(mask))
         y[~mask] = rng.uniform(0, 2, size=np.sum(~mask))
@@ -91,32 +92,36 @@ def test_compute_mi_cd():
 
     # Assert the same tolerance.
     for n_neighbors in [3, 5, 7]:
-        I_computed = _compute_mi(x, y, True, False, n_neighbors)
-        assert_almost_equal(I_computed, I_theory, 1)
+        I_computed = _compute_mi(
+            x, y, x_discrete=True, y_discrete=False, n_neighbors=n_neighbors
+        )
+        assert_allclose(I_computed, I_theory, rtol=1e-1)
 
 
-def test_compute_mi_cd_unique_label():
+def test_compute_mi_cd_unique_label(global_dtype):
     # Test that adding unique label doesn't change MI.
     n_samples = 100
     x = np.random.uniform(size=n_samples) > 0.5
 
-    y = np.empty(n_samples)
+    y = np.empty(n_samples, global_dtype)
     mask = x == 0
     y[mask] = np.random.uniform(-1, 1, size=np.sum(mask))
     y[~mask] = np.random.uniform(0, 2, size=np.sum(~mask))
 
-    mi_1 = _compute_mi(x, y, True, False)
+    mi_1 = _compute_mi(x, y, x_discrete=True, y_discrete=False)
 
     x = np.hstack((x, 2))
     y = np.hstack((y, 10))
-    mi_2 = _compute_mi(x, y, True, False)
+    mi_2 = _compute_mi(x, y, x_discrete=True, y_discrete=False)
 
-    assert mi_1 == mi_2
+    assert_allclose(mi_1, mi_2)
 
 
 # We are going test that feature ordering by MI matches our expectations.
-def test_mutual_info_classif_discrete():
-    X = np.array([[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]])
+def test_mutual_info_classif_discrete(global_dtype):
+    X = np.array(
+        [[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=global_dtype
+    )
     y = np.array([0, 1, 2, 2, 1])
 
     # Here X[:, 0] is the most informative feature, and X[:, 1] is weakly
```
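Note how the exact comparisons (`assert mi_1 == mi_2`, and `assert_array_equal` further down) are relaxed to `assert_allclose` here. That is the right call once the same test body runs under both float32 and float64: results that happen to match bit-for-bit in one precision can legitimately differ in the last unit of precision in the other. A tiny illustration of the general hazard (the values are arbitrary):

```python
import numpy as np

a64 = np.float64(0.1) + np.float64(0.2)
a32 = np.float32(0.1) + np.float32(0.2)

print(a64 == 0.3)                        # False: classic rounding artifact
print(np.isclose(a64, 0.3))              # True
print(np.isclose(a32, 0.3, rtol=1e-6))   # True: float32 needs a looser rtol
```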
```diff
@@ -125,7 +130,7 @@ def test_mutual_info_classif_discrete():
     assert_array_equal(np.argsort(-mi), np.array([0, 2, 1]))
 
 
-def test_mutual_info_regression():
+def test_mutual_info_regression(global_dtype):
     # We generate sample from multivariate normal distribution, using
     # transformation from initially uncorrelated variables. The zero
     # variables after transformation is selected as the target vector,
@@ -136,19 +141,22 @@ def test_mutual_info_regression():
     mean = np.zeros(4)
 
     rng = check_random_state(0)
-    Z = rng.multivariate_normal(mean, cov, size=1000)
+    Z = rng.multivariate_normal(mean, cov, size=1000).astype(global_dtype, copy=False)
     X = Z[:, 1:]
     y = Z[:, 0]
 
     mi = mutual_info_regression(X, y, random_state=0)
     assert_array_equal(np.argsort(-mi), np.array([1, 2, 0]))
+    # XXX: should mutual_info_regression be fixed to avoid
+    # up-casting float32 inputs to float64?
+    assert mi.dtype == np.float64
 
 
-def test_mutual_info_classif_mixed():
+def test_mutual_info_classif_mixed(global_dtype):
     # Here the target is discrete and there are two continuous and one
     # discrete feature. The idea of this test is clear from the code.
     rng = check_random_state(0)
-    X = rng.rand(1000, 3)
+    X = rng.rand(1000, 3).astype(global_dtype, copy=False)
     X[:, 1] += X[:, 0]
     y = ((0.5 * X[:, 0] + X[:, 2]) > 0.5).astype(int)
     X[:, 2] = X[:, 2] > 0.5
```
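The new `assert mi.dtype == np.float64` documents current behavior rather than desired behavior: as the XXX comment notes, `mutual_info_regression` up-casts float32 inputs internally, so the returned scores come back as float64 even when `X` and `y` are float32. A quick way to observe this (exact behavior may depend on the scikit-learn version):

```python
import numpy as np
from sklearn.feature_selection import mutual_info_regression

rng = np.random.RandomState(0)
X = rng.rand(200, 3).astype(np.float32)
y = rng.rand(200).astype(np.float32)

mi = mutual_info_regression(X, y, random_state=0)
# Inputs are float32, but the scores are up-cast to float64.
print(X.dtype, "->", mi.dtype)  # float32 -> float64
```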
```diff
@@ -168,9 +176,11 @@ def test_mutual_info_classif_mixed():
     assert mi_nn[2] == mi[2]
 
 
-def test_mutual_info_options():
-    X = np.array([[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=float)
-    y = np.array([0, 1, 2, 2, 1], dtype=float)
+def test_mutual_info_options(global_dtype):
+    X = np.array(
+        [[0, 0, 0], [1, 1, 0], [2, 0, 1], [2, 0, 1], [2, 0, 1]], dtype=global_dtype
+    )
+    y = np.array([0, 1, 2, 2, 1], dtype=global_dtype)
     X_csr = csr_matrix(X)
 
     for mutual_info in (mutual_info_regression, mutual_info_classif):
@@ -192,8 +202,8 @@ def test_mutual_info_options():
     mi_5 = mutual_info(X, y, discrete_features=[True, False, True], random_state=0)
     mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)
 
-    assert_array_equal(mi_1, mi_2)
-    assert_array_equal(mi_3, mi_4)
-    assert_array_equal(mi_5, mi_6)
+    assert_allclose(mi_1, mi_2)
+    assert_allclose(mi_3, mi_4)
+    assert_allclose(mi_5, mi_6)
 
-    assert not np.allclose(mi_1, mi_3)
+    assert not np.allclose(mi_1, mi_3)
```
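To reproduce both legs of these tests locally, the float32 variants have to be opted into explicitly. Assuming the `SKLEARN_RUN_FLOAT32_TESTS` switch sketched earlier, a small runner invoked from the repository root might look like:

```python
# Run the mutual information tests with the float32 leg enabled.
import os
import subprocess

env = dict(os.environ, SKLEARN_RUN_FLOAT32_TESTS="1")
subprocess.run(
    ["pytest", "sklearn/feature_selection/tests/test_mutual_info.py", "-v"],
    env=env,
    check=True,
)
```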
