8000 Revert "FIX _estimate_mi discrete_features str and value check (#13497)" · xhluca/scikit-learn@f36c7f1 · GitHub
[go: up one dir, main page]

Skip to content

Commit f36c7f1

Browse files
author
Xing
authored
Revert "FIX _estimate_mi discrete_features str and value check (scikit-learn#13497)"
This reverts commit ee08cd0.
1 parent 65984cb commit f36c7f1

File tree

2 files changed

+11
-21
lines changed

2 files changed

+11
-21
lines changed

sklearn/feature_selection/mutual_info_.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..preprocessing import scale
1111
from ..utils import check_random_state
1212
from ..utils.fixes import _astype_copy_false
13-
from ..utils.validation import check_array, check_X_y
13+
from ..utils.validation import check_X_y
1414
from ..utils.multiclass import check_classification_targets
1515

1616

@@ -247,16 +247,14 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
247247
X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target)
248248
n_samples, n_features = X.shape
249249

250-
if isinstance(discrete_features, (str, bool)):
251-
if isinstance(discrete_features, str):
252-
if discrete_features == 'auto':
253-
discrete_features = issparse(X)
254-
else:
255-
raise ValueError("Invalid string value for discrete_features.")
250+
if discrete_features == 'auto':
251+
discrete_features = issparse(X)
252+
253+
if isinstance(discrete_features, bool):
256254
discrete_mask = np.empty(n_features, dtype=bool)
257255
discrete_mask.fill(discrete_features)
258256
else:
259-
discrete_features = check_array(discrete_features, ensure_2d=False)
257+
discrete_features = np.asarray(discrete_features)
260258
if discrete_features.dtype != 'bool':
261259
discrete_mask = np.zeros(n_features, dtype=bool)
262260
discrete_mask[discrete_features] = True

sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -183,26 +183,18 @@ def test_mutual_info_options():
183183
X_csr = csr_matrix(X)
184184

185185
for mutual_info in (mutual_info_regression, mutual_info_classif):
186-
assert_raises(ValueError, mutual_info, X_csr, y,
186+
assert_raises(ValueError, mutual_info_regression, X_csr, y,
187187
discrete_features=False)
188-
assert_raises(ValueError, mutual_info, X, y,
189-
discrete_features='manual')
190-
assert_raises(ValueError, mutual_info, X_csr, y,
191-
discrete_features=[True, False, True])
192-
assert_raises(IndexError, mutual_info, X, y,
193-
discrete_features=[True, False, True, False])
194-
assert_raises(IndexError, mutual_info, X, y, discrete_features=[1, 4])
195188

196189
mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0)
197190
mi_2 = mutual_info(X, y, discrete_features=False, random_state=0)
198-
mi_3 = mutual_info(X_csr, y, discrete_features='auto', random_state=0)
199-
mi_4 = mutual_info(X_csr, y, discrete_features=True, random_state=0)
200-
mi_5 = mutual_info(X, y, discrete_features=[True, False, True],
191+
192+
mi_3 = mutual_info(X_csr, y, discrete_features='auto',
193+
random_state=0)
194+
mi_4 = mutual_info(X_csr, y, discrete_features=True,
201195
random_state=0)
202-
mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)
203196

204197
assert_array_equal(mi_1, mi_2)
205198
assert_array_equal(mi_3, mi_4)
206-
assert_array_equal(mi_5, mi_6)
207199

208200
assert not np.allclose(mi_1, mi_3)

0 commit comments

Comments
 (0)
0