10000 BUG always raise on NaN in OneHotEncoder for object dtype data (#12033) · scikit-learn/scikit-learn@dfdf605 · GitHub
[go: up one dir, main page]

Skip to content

Commit dfdf605

Browse files
jorisvandenbosscherth
authored andcommitted
BUG always raise on NaN in OneHotEncoder for object dtype data (#12033)
1 parent ec69171 commit dfdf605

File tree

2 files changed

+57
-8
lines changed

2 files changed

+57
-8
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import numpy as np
1111
from scipy import sparse
1212

13+
from .. import get_config as _get_config
1314
from ..base import BaseEstimator, TransformerMixin
1415
from ..externals import six
1516
from ..utils import check_array
1617
from ..utils import deprecated
17-
from ..utils.fixes import _argmax
18+
from ..utils.fixes import _argmax, _object_dtype_isnan
1819
from ..utils.validation import check_is_fitted
1920

2021
from .base import _transform_selected
@@ -37,14 +38,30 @@ class _BaseEncoder(BaseEstimator, TransformerMixin):
3738
3839
"""
3940

40-
def _fit(self, X, handle_unknown='error'):
41+
def _check_X(self, X):
42+
"""
43+
Perform custom check_array:
44+
- convert list of strings to object dtype
45+
- check for missing values for object dtype data (check_array does
46+
not do that)
4147
48+
"""
4249
X_temp = check_array(X, dtype=None)
4350
if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, np.str_):
4451
X = check_array(X, dtype=np.object)
4552
else:
4653
X = X_temp
4754

55+
if X.dtype == np.dtype('object'):
56+
if not _get_config()['assume_finite']:
57+
if _object_dtype_isnan(X).any():
58+
raise ValueError("Input contains NaN")
59+
60+
return X
61+
62+
def _fit(self, X, handle_unknown='error'):
63+
X = self._check_X(X)
64+
4865
n_samples, n_features = X.shape
4966

5067
if self._categories != 'auto':
@@ -74,12 +91,7 @@ def _fit(self, X, handle_unknown='error'):
7491
self.categories_.append(cats)
7592

7693
def _transform(self, X, handle_unknown='error'):
77-
78-
X_temp = check_array(X, dtype=None)
79-
if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, np.str_):
80-
X = check_array(X, dtype=np.object)
81-
else:
82-
X = X_temp
94+
X = self._check_X(X)
8395

8496
_, n_features = X.shape
8597
X_int = np.zeros_like(X, dtype=np.int)

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,25 @@ def test_one_hot_encoder_feature_names_unicode():
497497
assert_array_equal([u'n👍me_c❤t1', u'n👍me_dat2'], feature_names)
498498

499499

500+
@pytest.mark.parametrize("X", [np.array([[1, np.nan]]).T,
501+
np.array([['a', np.nan]], dtype=object).T],
502+
ids=['numeric', 'object'])
503+
@pytest.mark.parametrize("handle_unknown", ['error', 'ignore'])
504+
def test_one_hot_encoder_raise_missing(X, handle_unknown):
505+
ohe = OneHotEncoder(categories='auto', handle_unknown=handle_unknown)
506+
507+
with pytest.raises(ValueError, match="Input contains NaN"):
508+
ohe.fit(X)
509+
510+
with pytest.raises(ValueError, match="Input contains NaN"):
511+
ohe.fit_transform(X)
512+
513+
ohe.fit(X[:1, :])
514+
515+
with pytest.raises(ValueError, match="Input contains NaN"):
516+
ohe.transform(X)
517+
518+
500519
@pytest.mark.parametrize("X", [
501520
[['abc', 2, 55], ['def', 1, 55]],
502521
np.array([[10, 2, 55], [20, 1, 55]]),
@@ -524,6 +543,24 @@ def test_ordinal_encoder_inverse():
524543
assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr)
525544

526545

546+
@pytest.mark.parametrize("X", [np.array([[1, np.nan]]).T,
547+
np.array([['a', np.nan]], dtype=object).T],
548+
ids=['numeric', 'object'])
549+
def test_ordinal_encoder_raise_missing(X):
550+
ohe = OrdinalEncoder()
551+
552+
with pytest.raises(ValueError, match="Input contains NaN"):
553+
ohe.fit(X)
554+
555+
with pytest.raises(ValueError, match="Input contains NaN"):
556+
ohe.fit_transform(X)
557+
558+
ohe.fit(X[:1, :])
559+
560+
with pytest.raises(ValueError, match="Input contains NaN"):
561+
ohe.transform(X)
562+
563+
527564
def test_encoder_dtypes():
528565
# check that dtypes are preserved when determining categories
529566
enc = OneHotEncoder(categories='auto')

0 commit comments

Comments
 (0)
0