8000 remove the rest of _object_dtype_isnan usages · adrinjalali/scikit-learn@c6e867e · GitHub
[go: up one dir, main page]

Skip to content

Commit c6e867e

Browse files
committed
remove the rest of _object_dtype_isnan usages
1 parent 6408b73 commit c6e867e

File tree

2 files changed

+1
-18
lines changed
2 files changed
+1
-18
lines changed

sklearn/utils/tests/test_fixes.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from sklearn.utils.fixes import MaskedArray
1616
from sklearn.utils.fixes import _joblib_parallel_args
17-
from sklearn.utils.fixes import _object_dtype_isnan
1817
from sklearn.utils.fixes import loguniform
1918

2019

@@ -58,21 +57,6 @@ def test_joblib_parallel_args(monkeypatch, joblib_version):
5857
raise ValueError
5958

6059

61-
@pytest.mark.parametrize("dtype, val", ([object, 1],
62-
[object, "a"],
63-
[float, 1]))
64-
def test_object_dtype_isnan(dtype, val):
65-
X = np.array([[val, np.nan],
66-
[np.nan, val]], dtype=dtype)
67-
68-
expected_mask = np.array([[False, True],
69-
[True, False]])
70-
71-
mask = _object_dtype_isnan(X)
72-
73-
assert_array_equal(mask, expected_mask)
74-
75-
7660
@pytest.mark.parametrize("low,high,base",
7761
[(-1, 0, 10), (0, 2, np.exp(1)), (-1, 1, 2)])
7862
def test_loguniform(low, high, base):

sklearn/utils/validation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from numpy.core.numeric import ComplexWarning
2222
import joblib
2323

24-
from .fixes import _object_dtype_isnan
2524
from .. import get_config as _get_config
2625
from ..exceptions import NonBLASDotWarning, PositiveSpectrumWarning
2726
from ..exceptions import NotFittedError
@@ -61,7 +60,7 @@ def _assert_all_finite(X, allow_nan=False, msg_dtype=None):
6160
)
6261
# for object dtype data, we only check for NaNs (GH-13254)
6362
elif X.dtype == np.dtype('object') and not allow_nan:
64-
if _object_dtype_isnan(X).any():
63+
if np.isnan(X).any():
6564
raise ValueError("Input contains NaN")
6665

6766

0 commit comments

Comments
 (0)
0