FIX _estimate_mi discrete_features str and value check (#13497) · xhluca/scikit-learn@ee08cd0 · GitHub
[go: up one dir, main page]

Skip to content

Commit ee08cd0

Browse files
hermidalc authored and Xing committed
FIX _estimate_mi discrete_features str and value check (scikit-learn#13497)
* discrete_features str and value check
* Update if logic
* Add discrete_features bad str value test
* Remove unnecessary nested isinstance str check
* Add back nested isinstance str check
* New/updates to tests
* Add v0.21 whats new entry
* Undo v0.21 whats new entry
1 parent cd7764d commit ee08cd0

File tree

2 files changed

+21
-11
lines changed

2 files changed

+21
-11
lines changed

sklearn/feature_selection/mutual_info_.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ..preprocessing import scale
1111
from ..utils import check_random_state
1212
from ..utils.fixes import _astype_copy_false
13-
from ..utils.validation import check_X_y
13+
from ..utils.validation import check_array, check_X_y
1414
from ..utils.multiclass import check_classification_targets
1515

1616

@@ -247,14 +247,16 @@ def _estimate_mi(X, y, discrete_features='auto', discrete_target=False,
247247
X, y = check_X_y(X, y, accept_sparse='csc', y_numeric=not discrete_target)
248248
n_samples, n_features = X.shape
249249

250-
if discrete_features == 'auto':
251-
discrete_features = issparse(X)
252-
253-
if isinstance(discrete_features, bool):
250+
if isinstance(discrete_features, (str, bool)):
251+
if isinstance(discrete_features, str):
252+
if discrete_features == 'auto':
253+
discrete_features = issparse(X)
254+
else:
255+
raise ValueError("Invalid string value for discrete_features.")
254256
discrete_mask = np.empty(n_features, dtype=bool)
255257
discrete_mask.fill(discrete_features)
256258
else:
257-
discrete_features = np.asarray(discrete_features)
259+
discrete_features = check_array(discrete_features, ensure_2d=False)
258260
if discrete_features.dtype != 'bool':
259261
discrete_mask = np.zeros(n_features, dtype=bool)
260262
discrete_mask[discrete_features] = True

sklearn/feature_selection/tests/test_mutual_info.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,18 +183,26 @@ def test_mutual_info_options():
183183
X_csr = csr_matrix(X)
184184

185185
for mutual_info in (mutual_info_regression, mutual_info_classif):
186-
assert_raises(ValueError, mutual_info_regression, X_csr, y,
186+
assert_raises(ValueError, mutual_info, X_csr, y,
187187
discrete_features=False)
188+
assert_raises(ValueError, mutual_info, X, y,
189+
discrete_features='manual')
190+
assert_raises(ValueError, mutual_info, X_csr, y,
191+
discrete_features=[True, False, True])
192+
assert_raises(IndexError, mutual_info, X, y,
193+
discrete_features=[True, False, True, False])
194+
assert_raises(IndexError, mutual_info, X, y, discrete_features=[1, 4])
188195

189196
mi_1 = mutual_info(X, y, discrete_features='auto', random_state=0)
190197
mi_2 = mutual_info(X, y, discrete_features=False, random_state=0)
191-
192-
mi_3 = mutual_info(X_csr, y, discrete_features='auto',
193-
random_state=0)
194-
mi_4 = mutual_info(X_csr, y, discrete_features=True,
198+
mi_3 = mutual_info(X_csr, y, discrete_features='auto', random_state=0)
199+
mi_4 = mutual_info(X_csr, y, discrete_features=True, random_state=0)
200+
mi_5 = mutual_info(X, y, discrete_features=[True, False, True],
195201
random_state=0)
202+
mi_6 = mutual_info(X, y, discrete_features=[0, 2], random_state=0)
196203

197204
assert_array_equal(mi_1, mi_2)
198205
assert_array_equal(mi_3, mi_4)
206+
assert_array_equal(mi_5, mi_6)
199207

200208
assert not np.allclose(mi_1, mi_3)

0 commit comments

Comments
 (0)
0