handle_missing = 'indicator' · scikit-learn/scikit-learn@8ab4ec8 · GitHub
Commit 8ab4ec8

Author: Katrina Ni
Commit message: handle_missing = 'indicator'
1 parent ac0d240 · commit 8ab4ec8

File tree: 3 files changed (+48 / -40 lines changed)


sklearn/preprocessing/_encoders.py
Lines changed: 22 additions & 16 deletions
@@ -2,18 +2,15 @@
 # Joris Van den Bossche <jorisvandenbossche@gmail.com>
 # License: BSD 3 clause
 
-import warnings
-
 import numpy as np
 from scipy import sparse
 
 from ..base import BaseEstimator, TransformerMixin
 from ..utils import check_array
-from ..utils.validation import check_is_fitted
 from ..utils.fixes import _object_dtype_isnan
+from ..utils.validation import check_is_fitted
 from ._label import _encode, _encode_check_unknown
 
-
 __all__ = [
     'OneHotEncoder',
     'OrdinalEncoder'
@@ -93,19 +90,20 @@ def _fit(self, X, handle_unknown='error'):
                                          "supported for numerical categories")
             if handle_unknown == 'error':
                 # NaNs don't count as categoreis during fit
-                diff = _encode_check_unknown(Xi[~_object_dtype_isnan(Xi)], cats)
+                diff = _encode_check_unknown(
+                    Xi[~_object_dtype_isnan(Xi)], cats)
                 if diff:
                     msg = ("Found unknown categories {0} in column {1}"
                            " during fit".format(diff, i))
                     raise ValueError(msg)
             self.categories_.append(cats)
 
     def _transform(self, X, handle_unknown='error', handle_missing=None):
-        X_list, n_samples, n_features = self._check_X(
-            X)
-        # from now on, either X is w.o. NaNs or w. NaNs yet handle_missing != None.
-        # in the later case, since we'll handle NaNs separately,
-        # NaNs don't count as unknown categories
+        X_list, n_samples, n_features = self._check_X(X)
+        # from now on, either X is w.o. NaNs
+        # or w. NaNs yet handle_missing != None.
+        # since we'll handle NaNs separately so that it does not intefere
+        # with handle_unknown, we won't count NaNs as unknown categories
         X_int = np.zeros((n_samples, n_features), dtype=np.int)
         X_mask = np.ones((n_samples, n_features), dtype=np.bool)

@@ -137,7 +135,7 @@ def _transform(self, X, handle_unknown='error', handle_missing=None):
                         Xi = Xi.astype(self.categories_[i].dtype)
                     else:
                         Xi = Xi.copy()
-
+
                 if handle_missing == 'indicator':
                     valid_mask = na_valid_mask
                     Xi[~valid_mask] = self.categories_[i][0]
@@ -151,6 +149,11 @@ def _transform(self, X, handle_unknown='error', handle_missing=None):
                                  check_unknown=False)
             X_int[:, i] = encoded
 
+            if (self.handle_missing == 'indicator' and
+                    _object_dtype_isnan(Xi).sum() > 0):
+                self.categories_[i] = np.append(
+                    np.array(self.categories_[i], dtype=object), None)
+
         return X_int, X_mask
 
     def _more_tags(self):
@@ -230,7 +233,8 @@ class OneHotEncoder(_BaseEncoder):
         will be denoted as None.
 
     handle_missing : {'indicator', 'all-zero'}, default=None
-        Specify how to handle missing categorical features (NaN) in the training data
+        Specify how to handle missing categorical features (NaN) in the
+        training data
 
         - None : Raise an error in the presence of NaN (the default).
         - 'indicator': Represent with a separate one-hot column.
@@ -310,7 +314,8 @@ class OneHotEncoder(_BaseEncoder):
     """
 
     def __init__(self, categories='auto', drop=None, sparse=True,
-                 dtype=np.float64, handle_unknown='error', handle_missing=None):
+                 dtype=np.float64, handle_unknown='error',
+                 handle_missing=None):
         self.categories = categories
         self.sparse = sparse
         self.dtype = dtype
@@ -441,7 +446,8 @@ def transform(self, X):
         check_is_fitted(self)
         # validation of X happens in _check_X called by _transform
         X_int, X_mask = self._transform(
-            X, handle_unknown=self.handle_unknown, handle_missing=self.handle_missing)
+            X, handle_unknown=self.handle_unknown,
+            handle_missing=self.handle_missing)
 
         n_samples, n_features = X_int.shape
 
@@ -486,7 +492,6 @@ def transform(self, X):
         else:
             return out
 
-
     def inverse_transform(self, X):
         """
         Convert the data back to the original representation.
@@ -549,7 +554,8 @@ def inverse_transform(self, X):
                 # for sparse X argmax returns 2D matrix, ensure 1D array
                 labels = np.asarray(sub.argmax(axis=1)).flatten()
                 X_tr[:, i] = cats[labels]
-                if self.handle_unknown == 'ignore' or self.handle_missing == 'all-zero':
+                if (self.handle_unknown == 'ignore' or
+                        self.handle_missing == 'all-zero'):
                     unknown = np.asarray(sub.sum(axis=1) == 0).flatten()
                     # ignored unknown categories: we have a row of all zero
                     if unknown.any():
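
For orientation, a minimal usage sketch of what this commit wires into OneHotEncoder. The handle_missing parameter exists only on this development branch, not in released scikit-learn, and the input below is invented for illustration rather than taken from the commit:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Toy object-dtype column with one missing entry (invented, branch-only example).
X = np.array([['a'], ['b'], [np.nan], ['b']], dtype=object)

enc = OneHotEncoder(sparse=False, handle_missing='indicator')
Xt = enc.fit_transform(X)
# Expected on this branch: three columns ('a', 'b', NaN indicator),
# with row 2 set only in the indicator column.
print(Xt)
# After fit_transform, the new _transform logic appends a trailing None
# to the fitted categories of the column that contained NaN.
print(enc.categories_)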

sklearn/preprocessing/_label.py
Lines changed: 6 additions & 4 deletions
@@ -67,10 +67,11 @@ def _encode_python(values, uniques=None, encode=False, check_unknown=True):
                 encoded = np.array([table[v] for v in values])
             except KeyError as e:
                 raise ValueError("y contains previously unseen labels: %s"
-                                     % str(e))
+                                 % str(e))
         else:
-            encoded = np.array([table[v] if v in table else n_uniques for v in values ])
-
+            encoded = np.array(
+                [table[v] if v in table else n_uniques for v in values])
+
         return uniques, encoded
     else:
         return uniques
@@ -114,7 +115,8 @@ def _encode(values, uniques=None, encode=False, check_unknown=True):
     """
    if values.dtype == object:
        try:
-            res = _encode_python(values, uniques, encode, check_unknown=check_unknown)
+            res = _encode_python(values, uniques, encode,
+                                 check_unknown=check_unknown)
        except TypeError:
            types = sorted(t.__qualname__
                           for t in set(type(v) for v in values))
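
The rewrapped line in _encode_python is the branch the encoders hit when they call _encode(..., check_unknown=False): values present in the lookup table map to their category index, and anything unseen falls through to n_uniques, one position past the last known category. A standalone sketch of that lookup pattern, using invented data but the variable names from the diff:

import numpy as np

uniques = np.array(['a', 'b', 'c'], dtype=object)
values = np.array(['b', 'a', 'z', 'c'], dtype=object)

# Known values map to their index in `uniques`; unseen values map to the
# extra slot at n_uniques, mirroring `table[v] if v in table else n_uniques`.
table = {val: i for i, val in enumerate(uniques)}
n_uniques = len(uniques)
encoded = np.array([table[v] if v in table else n_uniques for v in values])
print(encoded)  # [1 0 3 2] -- 'z' lands in the extra slot at index 3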

sklearn/preprocessing/tests/test_encoders.py
Lines changed: 20 additions & 20 deletions
@@ -462,8 +462,8 @@ def test_one_hot_encoder_raise_missing(X, as_data_frame, handle_unknown):
 
     ohe = OneHotEncoder(categories='auto', handle_unknown=handle_unknown)
 
-    # with pytest.raises(ValueError, match="Input contains NaN"):
-    #     ohe.fit(X)
+    with pytest.raises(ValueError, match="Input contains NaN"):
+        ohe.fit(X)
 
     with pytest.raises(ValueError, match="Input contains NaN"):
         ohe.fit_transform(X)
@@ -490,14 +490,17 @@ def test_one_hot_encoder_handle_missing(X, as_data_frame, handle_unknown):
         pd = pytest.importorskip('pandas')
         X = pd.DataFrame(X)
 
-    # enc_ind = OneHotEncoder(
-    #     categories='auto', sparse=False, handle_missing='indicator')
-    # exp_ind = np.array([[1, 0, 0],
-    #                     [0, 1, 0],
-    #                     [0, 0, 1],
-    #                     [0, 1, 0]], dtype='int64')
-    # print(enc_ind.fit_transform(X))
-    # assert_array_equal(enc_ind.fit_transform(X), exp_ind.astype('float64'))
+    X_inv = np.array(X, dtype=object)
+    X_inv[2, 0] = None
+
+    enc_ind = OneHotEncoder(
+        categories='auto', sparse=False, handle_missing='indicator')
+    exp_ind = np.array([[1, 0, 0],
+                        [0, 1, 0],
+                        [0, 0, 1],
+                        [0, 1, 0]], dtype='int64')
+    assert_array_equal(enc_ind.fit_transform(X), exp_ind.astype('float64'))
+    assert_array_equal(enc_ind.inverse_transform(exp_ind), X_inv)
 
     enc_zero = OneHotEncoder(
         categories='auto', sparse=False, handle_missing='all-zero')
@@ -506,9 +509,6 @@ def test_one_hot_encoder_handle_missing(X, as_data_frame, handle_unknown):
                           [0, 0],
                           [0, 1]], dtype='int64')
     assert_array_equal(enc_zero.fit_transform(X), exp_zero.astype('float64'))
-
-    X_inv = np.array(X, dtype=object)
-    X_inv[2, 0] = None
     assert_array_equal(enc_zero.inverse_transform(exp_zero), X_inv)
 
@@ -574,8 +574,8 @@ def test_ordinal_encoder_inverse():
 def test_ordinal_encoder_raise_missing(X):
     ohe = OrdinalEncoder()
 
-    # with pytest.raises(ValueError, match="Input contains NaN"):
-    #     ohe.fit(X)
+    with pytest.raises(ValueError, match="Input contains NaN"):
+        ohe.fit(X)
 
     with pytest.raises(ValueError, match="Input contains NaN"):
         ohe.fit_transform(X)
@@ -670,15 +670,15 @@ def test_one_hot_encoder_drop_manual():
 @pytest.mark.parametrize(
     "X_fit, params, err_msg",
     [([["Male"], ["Female"]], {'drop': 'second'},
-      "Wrong input for parameter `drop`"),
+     "Wrong input for parameter `drop`"),
      ([["Male"], ["Female"]], {'drop': 'first', 'handle_unknown': 'ignore'},
-      "`handle_unknown` must be 'error'"),
+     "`handle_unknown` must be 'error'"),
      ([['abc', 2, 55], ['def', 1, 55], ['def', 3, 59]],
       {'drop': np.asarray('b', dtype=object)},
-      "Wrong input for parameter `drop`"),
+     "Wrong input for parameter `drop`"),
      ([['abc', 2, 55], ['def', 1, 55], ['def', 3, 59]],
       {'drop': ['ghi', 3, 59]},
-      "The following categories were supposed")]
+     "The following categories were supposed")]
     )
 def test_one_hot_encoder_invalid_params(X_fit, params, err_msg):
     enc = OneHotEncoder(**params)
@@ -728,4 +728,4 @@ def test_encoders_does_not_support_none_values(Encoder):
     values = [["a"], [None]]
     with pytest.raises(TypeError, match="Encoders require their input to be "
                        "uniformly strings or numbers."):
-        Encoder().fit_transform(values)
+        Encoder().fit(values)
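
The updated test also exercises the 'all-zero' mode: the row holding the missing value encodes to all zeros, and inverse_transform maps an all-zero row back to None, which is why the test sets X_inv[2, 0] = None. A hedged sketch of that round trip; it assumes this development branch and an invented input chosen to match the expected matrices in the test:

import numpy as np
from sklearn.preprocessing import OneHotEncoder

# Invented 4x1 input with a missing value in row 2 (branch-only example).
X = np.array([['a'], ['b'], [np.nan], ['b']], dtype=object)

enc = OneHotEncoder(sparse=False, handle_missing='all-zero')
Xt = enc.fit_transform(X)
print(Xt)  # expected: only 'a' and 'b' columns; row 2 is all zeros

# Mirroring the inverse_transform change above, an all-zero row decodes to None.
print(enc.inverse_transform(Xt))  # expected row 2: [None]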
