From 7cb8ad48a5202ffa79218b66a331164b9c5716bc Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 7 Sep 2018 19:51:01 +0200 Subject: [PATCH 1/4] BUG: always raise on NaN in OneHotEncoder for object dtype data --- sklearn/preprocessing/_encoders.py | 26 +++++++++----- sklearn/preprocessing/tests/test_encoders.py | 37 ++++++++++++++++++++ 2 files changed, 55 insertions(+), 8 deletions(-) diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index bd6e10fb62810..bb07feee54352 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -14,7 +14,7 @@ from ..externals import six from ..utils import check_array from ..utils import deprecated -from ..utils.fixes import _argmax +from ..utils.fixes import _argmax, _object_dtype_isnan from ..utils.validation import check_is_fitted from .base import _transform_selected @@ -37,14 +37,29 @@ class _BaseEncoder(BaseEstimator, TransformerMixin): """ - def _fit(self, X, handle_unknown='error'): + def _check_X(self, X): + """ + Perform custom check_array: + - convert list of strings to object dtype + - check for missing values for object dtype data (check_array does + not do that) + """ X_temp = check_array(X, dtype=None) if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, np.str_): X = check_array(X, dtype=np.object) else: X = X_temp + if X.dtype == np.dtype('object'): + if _object_dtype_isnan(X).any(): + raise ValueError("Input contains NaN") + + return X + + def _fit(self, X, handle_unknown='error'): + X = self._check_X(X) + n_samples, n_features = X.shape if self._categories != 'auto': @@ -74,12 +89,7 @@ def _fit(self, X, handle_unknown='error'): self.categories_.append(cats) def _transform(self, X, handle_unknown='error'): - - X_temp = check_array(X, dtype=None) - if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, np.str_): - X = check_array(X, dtype=np.object) - else: - X = X_temp + X = self._check_X(X) _, n_features = X.shape X_int = np.zeros_like(X, dtype=np.int) diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index 9ec16b85df60d..a9b0cad598ed6 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -497,6 +497,25 @@ def test_one_hot_encoder_feature_names_unicode(): assert_array_equal([u'n👍me_c❤t1', u'n👍me_dat2'], feature_names) +@pytest.mark.parametrize("X", [np.array([[1, np.nan]]).T, + np.array([['a', np.nan]], dtype=object).T], + ids=['numeric', 'object']) +@pytest.mark.parametrize("handle_unknown", ['error', 'ignore']) +def test_one_hot_encoder_raise_missing(X, handle_unknown): + ohe = OneHotEncoder(categories='auto', handle_unknown=handle_unknown) + + with pytest.raises(ValueError): + ohe.fit(X) + + with pytest.raises(ValueError): + ohe.fit_transform(X) + + ohe.fit(X[:1, :]) + + with pytest.raises(ValueError): + ohe.transform(X) + + @pytest.mark.parametrize("X", [ [['abc', 2, 55], ['def', 1, 55]], np.array([[10, 2, 55], [20, 1, 55]]), @@ -524,6 +543,24 @@ def test_ordinal_encoder_inverse(): assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr) +@pytest.mark.parametrize("X", [np.array([[1, np.nan]]).T, + np.array([['a', np.nan]], dtype=object).T], + ids=['numeric', 'object']) +def test_ordinal_encoder_raise_missing(X): + ohe = OrdinalEncoder() + + with pytest.raises(ValueError): + ohe.fit(X) + + with pytest.raises(ValueError): + ohe.fit_transform(X) + + ohe.fit(X[:1, :]) + + with pytest.raises(ValueError): + ohe.transform(X) + + def test_encoder_dtypes(): # check that dtypes are preserved when determining categories enc = OneHotEncoder(categories='auto') From 770204c63d978ae093c355004de26f2e284066dd Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Fri, 7 Sep 2018 20:23:45 +0200 Subject: [PATCH 2/4] pep8 --- sklearn/preprocessing/_encoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index bb07feee54352..fee0ae50ad8c3 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -52,8 +52,8 @@ def _check_X(self, X): X = X_temp if X.dtype == np.dtype('object'): - if _object_dtype_isnan(X).any(): - raise ValueError("Input contains NaN") + if _object_dtype_isnan(X).any(): + raise ValueError("Input contains NaN") return X From 891f24440ad834fb3968876f446bb74ea0903725 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 12 Sep 2018 10:41:58 +0200 Subject: [PATCH 3/4] add check for assume_finite config --- sklearn/preprocessing/_encoders.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/_encoders.py b/sklearn/preprocessing/_encoders.py index fee0ae50ad8c3..10324e17061e8 100644 --- a/sklearn/preprocessing/_encoders.py +++ b/sklearn/preprocessing/_encoders.py @@ -10,6 +10,7 @@ import numpy as np from scipy import sparse +from .. import get_config as _get_config from ..base import BaseEstimator, TransformerMixin from ..externals import six from ..utils import check_array @@ -52,8 +53,9 @@ def _check_X(self, X): X = X_temp if X.dtype == np.dtype('object'): - if _object_dtype_isnan(X).any(): - raise ValueError("Input contains NaN") + if not _get_config()['assume_finite']: + if _object_dtype_isnan(X).any(): + raise ValueError("Input contains NaN") return X From 210156625a13fb3a246f6d448333f9ce8e5a8477 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 12 Sep 2018 10:43:25 +0200 Subject: [PATCH 4/4] add match string to error assert --- sklearn/preprocessing/tests/test_encoders.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sklearn/preprocessing/tests/test_encoders.py b/sklearn/preprocessing/tests/test_encoders.py index a9b0cad598ed6..67169432defdc 100644 --- a/sklearn/preprocessing/tests/test_encoders.py +++ b/sklearn/preprocessing/tests/test_encoders.py @@ -504,15 +504,15 @@ def test_one_hot_encoder_feature_names_unicode(): def test_one_hot_encoder_raise_missing(X, handle_unknown): ohe = OneHotEncoder(categories='auto', handle_unknown=handle_unknown) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.fit(X) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.fit_transform(X) ohe.fit(X[:1, :]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.transform(X) @@ -549,15 +549,15 @@ def test_ordinal_encoder_inverse(): def test_ordinal_encoder_raise_missing(X): ohe = OrdinalEncoder() - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.fit(X) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.fit_transform(X) ohe.fit(X[:1, :]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="Input contains NaN"): ohe.transform(X)