8000 FIX remove FutureWarning in _object_dtype_isnan and add test (#12567) · xhluca/scikit-learn@178f87c · GitHub
[go: up one dir, main page]

Skip to content

Commit 178f87c

Browse files
jeremiedbbXing
authored andcommitted
FIX remove FutureWarning in _object_dtype_isnan and add test (scikit-learn#12567)
1 parent 35bdb21 commit 178f87c

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

sklearn/utils/fixes.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -309,18 +309,13 @@ def nanmedian(a, axis=None):
309309
# Fix for behavior inconsistency on numpy.equal for object dtypes.
310310
# For numpy versions < 1.13, numpy.equal tests element-wise identity of objects
311311
# instead of equality. This fix returns the mask of NaNs in an array of
312-
# numerical or object values for all nupy versions.
313-
314-
_nan_object_array = np.array([np.nan], dtype=object)
315-
_nan_object_mask = _nan_object_array != _nan_object_array
316-
317-
if np.array_equal(_nan_object_mask, np.array([True])):
312+
# numerical or object values for all numpy versions.
313+
if np_version < (1, 13):
318314
def _object_dtype_isnan(X):
319-
return X != X
320-
< 10000 /div>
315+
return np.frompyfunc(lambda x: x != x, 1, 1)(X).astype(bool)
321316
else:
322317
def _object_dtype_isnan(X):
323-
return np.frompyfunc(lambda x: x != x, 1, 1)(X).astype(bool)
318+
return X != X
324319

325320

326321
# To be removed once this fix is included in six

sklearn/utils/tests/test_fixes.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.utils.fixes import nanmedian
1818
from sklearn.utils.fixes import nanpercentile
1919
from sklearn.utils.fixes import _joblib_parallel_args
20+
from sklearn.utils.fixes import _object_dtype_isnan
2021

2122

2223
def test_divide():
@@ -88,3 +89,18 @@ def test_joblib_parallel_args(monkeypatch, joblib_version):
8889
_joblib_parallel_args(verbose=True)
8990
else:
9091
raise ValueError
92+
93+
94+
@pytest.mark.parametrize("dtype, val", ([object, 1],
95+
[object, "a"],
96+
[float, 1]))
97+
def test_object_dtype_isnan(dtype, val):
98+
X = np.array([[val, np.nan],
99+
[np.nan, val]], dtype=dtype)
100+
101+
expected_mask = np.array([[False, True],
102+
[True, False]])
103+
104+
mask = _object_dtype_isnan(X)
105+
106+
assert_array_equal(mask, expected_mask)

0 commit comments

Comments
 (0)
0