8000 Revert "FIX remove FutureWarning in _object_dtype_isnan and add test … · xhluca/scikit-learn@633f907 · GitHub
[go: up one dir, main page]

Skip to content

Commit 633f907

Browse files
author
Xing
authored
Revert "FIX remove FutureWarning in _object_dtype_isnan and add test (scikit-learn#12567)"
This reverts commit 178f87c.
1 parent b22ae86 commit 633f907

File tree

2 files changed

+9
-20
lines changed

2 files changed

+9
-20
lines changed

sklearn/utils/fixes.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,13 +309,18 @@ def nanmedian(a, axis=None):
309309
# Fix for behavior inconsistency on numpy.equal for object dtypes.
310310
# For numpy versions < 1.13, numpy.equal tests element-wise identity of objects
311311
# instead of equality. This fix returns the mask of NaNs in an array of
312-
# numerical or object values for all numpy versions.
313-
if np_version < (1, 13):
312+
# numerical or object values for all nupy versions.
313+
314+
_nan_object_array = np.array([np.nan], dtype=object)
315+
_nan_object_mask = _nan_object_array != _nan_object_array
316+
317+
if np.array_equal(_nan_object_mask, np.array([True])):
314318
def _object_dtype_isnan(X):
315-
return np.frompyfunc(lambda x: x != x, 1, 1)(X).astype(bool)
319+
return X != X
320+
316321
else:
317322
def _object_dtype_isnan(X):
318-
return X != X
323+
return np.frompyfunc(lambda x: x != x, 1, 1)(X).astype(bool)
319324

320325

321326
# To be removed once this fix is included in six

sklearn/utils/tests/test_fixes.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from sklearn.utils.fixes import nanmedian
1818
from sklearn.utils.fixes import nanpercentile
1919
from sklearn.utils.fixes import _joblib_parallel_args
20-
from sklearn.utils.fixes import _object_dtype_isnan
2120

2221

2322
def test_divide():
@@ -89,18 +88,3 @@ def test_joblib_parallel_args(monkeypatch, joblib_version):
8988
_joblib_parallel_args(verbose=True)
9089
else:
9190
raise ValueError
92-
93-
94-
@pytest.mark.parametrize("dtype, val", ([object, 1],
95-
[object, "a"],
96-
[float, 1]))
97-
def test_object_dtype_isnan(dtype, val):
98-
X = np.array([[val, np.nan],
99-
[np.nan, val]], dtype=dtype)
100-
101-
expected_mask = np.array([[False, True],
102-
[True, False]])
103-
104-
mask = _object_dtype_isnan(X)
105-
106-
assert_array_equal(mask, expected_mask)

0 commit comments

Comments
 (0)
0