ENH Adds missing value support to OneHotEncoder (#17317) · thomasjpfan/scikit-learn@eeddc17 · GitHub
[go: up one dir, main page]

Skip to content

Commit eeddc17

Browse files
authored
ENH Adds missing value support to OneHotEncoder (scikit-learn#17317)
OneHotEncoder supports categorical features with missing values by considering the missing values as an additional category.
1 parent 3b334c5 commit eeddc17

File tree

8 files changed

+499
-109
lines changed

8 files changed

+499
-109
lines changed

doc/modules/preprocessing.rst

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -590,6 +590,34 @@ In the transformed `X`, the first column is the encoding of the feature with
590590
categories "male"/"female", while the remaining 6 columns are the encoding of
591591
the 2 features with respectively 3 categories each.
592592

593+
:class:`OneHotEncoder` supports categorical features with missing values by
594+
considering the missing values as an additional category::
595+
596+
>>> X = [['male', 'Safari'],
597+
... ['female', None],
598+
... [np.nan, 'Firefox']]
599+
>>> enc = preprocessing.OneHotEncoder(handle_unknown='error').fit(X)
600+
>>> enc.categories_
601+
[array(['female', 'male', nan], dtype=object),
602+
array(['Firefox', 'Safari', None], dtype=object)]
603+
>>> enc.transform(X).toarray()
604+
array([[0., 1., 0., 0., 1., 0.],
605+
[1., 0., 0., 0., 0., 1.],
606+
[0., 0., 1., 1., 0., 0.]])
607+
608+
If a feature contains both `np.nan` and `None`, they will be considered
609+
separate categories::
610+
611+
>>> X = [['Safari'], [None], [np.nan], ['Firefox']]
612+
>>> enc = preprocessing.OneHotEncoder(handle_unknown='error').fit(X)
613+
>>> enc.categories_
614+
[array(['Firefox', 'Safari', None, nan], dtype=object)]
615+
>>> enc.transform(X).toarray()
616+
array([[0., 1., 0., 0.],
617+
[0., 0., 1., 0.],
618+
[0., 0., 0., 1.],
619+
[1., 0., 0., 0.]])
620+
593621
See :ref:`dict_feature_extraction` for categorical features that are
594622
represented as a dict, not as scalars.
595623

@@ -791,5 +819,5 @@ error with a ``filterwarnings``::
791819
... category=UserWarning, append=False)
792820

793821
For a full code example that demonstrates using a :class:`FunctionTransformer`
794-
to extract features from text data see
822+
to extract features from text data see
795823
:ref:`sphx_glr_auto_examples_compose_plot_column_transformer.py`

doc/whats_new/v0.24.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,9 @@ Changelog
552552
:mod:`sklearn.preprocessing`
553553
............................
554554

555+
- |Feature| :class:`preprocessing.OneHotEncoder` now supports missing
556+
values by treating them as a category. :pr:`17317` by `Thomas Fan`_.
557+
555558
- |Feature| Add a new ``handle_unknown`` parameter with a
556559
``use_encoded_value`` option, along with a new ``unknown_value`` parameter,
557560
to :class:`preprocessing.OrdinalEncoder` to allow unknown categories during

examples/compose/plot_column_transformer_mixed_types.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,7 @@
7272
('scaler', StandardScaler())])
7373

7474
categorical_features = ['embarked', 'sex', 'pclass']
75-
categorical_transformer = Pipeline(steps=[
76-
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
77-
('onehot', OneHotEncoder(handle_unknown='ignore'))])
75+
categorical_transformer = OneHotEncoder(handle_unknown='ignore')
7876

7977
preprocessor = ColumnTransformer(
8078
transformers=[

examples/inspection/plot_permutation_importance.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,13 @@
6363
X_train, X_test, y_train, y_test = train_test_split(
6464
X, y, stratify=y, random_state=42)
6565

66-
categorical_pipe = Pipeline([
67-
('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
68-
('onehot', OneHotEncoder(handle_unknown='ignore'))
69-
])
66+
categorical_encoder = OneHotEncoder(handle_unknown='ignore')
7067
numerical_pipe = Pipeline([
7168
('imputer', SimpleImputer(strategy='mean'))
7269
])
7370

7471
preprocessing = ColumnTransformer(
75-
[('cat', categorical_pipe, categorical_columns),
72+
[('cat', categorical_encoder, categorical_columns),
7673
('num', numerical_pipe, numerical_columns)])
7774

7875
rf = Pipeline([
@@ -122,8 +119,7 @@
122119
# predictions that generalize to the test set (when the model has enough
123120
# capacity).
124121
ohe = (rf.named_steps['preprocess']
125-
.named_transformers_['cat']
126-
.named_steps['onehot'])
122+
.named_transformers_['cat'])
127123
feature_names = ohe.get_feature_names(input_features=categorical_columns)
128124
feature_names = np.r_[feature_names, numerical_columns]
129125

sklearn/preprocessing/_encoders.py

Lines changed: 47 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class _BaseEncoder(TransformerMixin, BaseEstimator):
2727
2828
"""
2929

30-
def _check_X(self, X):
30+
def _check_X(self, X, force_all_finite=True):
3131
"""
3232
Perform custom check_array:
3333
- convert list of strings to object dtype
@@ -41,17 +41,19 @@ def _check_X(self, X):
4141
"""
4242
if not (hasattr(X, 'iloc') and getattr(X, 'ndim', 0) == 2):
4343
# if not a dataframe, do normal check_array validation
44-
X_temp = check_array(X, dtype=None)
44+
X_temp = check_array(X, dtype=None,
45+
force_all_finite=force_all_finite)
4546
if (not hasattr(X, 'dtype')
4647
and np.issubdtype(X_temp.dtype, np.str_)):
47-
X = check_array(X, dtype=object)
48+
X = check_array(X, dtype=object,
49+
force_all_finite=force_all_finite)
4850
else:
4951
X = X_temp
5052
needs_validation = False
5153
else:
5254
# pandas dataframe, do validation later column by column, in order
5355
# to keep the dtype information to be used in the encoder.
54-
needs_validation = True
56+
needs_validation = force_all_finite
5557

5658
n_samples, n_features = X.shape
5759
X_columns = []
@@ -71,8 +73,9 @@ def _get_feature(self, X, feature_idx):
7173
# numpy arrays, sparse arrays
7274
return X[:, feature_idx]
7375

74-
def _fit(self, X, handle_unknown='error'):
75-
X_list, n_samples, n_features = self._check_X(X)
76+
def _fit(self, X, handle_unknown='error', force_all_finite=True):
77+
X_list, n_samples, n_features = self._check_X(
78+
X, force_all_finite=force_all_finite)
7679

7780
if self.categories != 'auto':
7881
if len(self.categories) != n_features:
@@ -88,9 +91,16 @@ def _fit(self, X, handle_unknown='error'):
8891
else:
8992
cats = np.array(self.categories[i], dtype=Xi.dtype)
9093
if Xi.dtype != object:
91-
if not np.all(np.sort(cats) == cats):
92-
raise ValueError("Unsorted categories are not "
93-
"supported for numerical categories")
94+
sorted_cats = np.sort(cats)
95+
error_msg = ("Unsorted categories are not "
96+
"supported for numerical categories")
97+
# if there are nans, nan should be the last element
98+
stop_idx = -1 if np.isnan(sorted_cats[-1]) else None
99+
if (np.any(sorted_cats[:stop_idx] != cats[:stop_idx]) or
100+
(np.isnan(sorted_cats[-1]) and
101+
not np.isnan(sorted_cats[-1]))):
102+
raise ValueError(error_msg)
103+
94104
if handle_unknown == 'error':
95105
diff = _check_unknown(Xi, cats)
96106
if diff:
@@ -99,8 +109,9 @@ def _fit(self, X, handle_unknown='error'):
99109
raise ValueError(msg)
100110
self.categories_.append(cats)
101111

102-
def _transform(self, X, handle_unknown='error'):
103-
X_list, n_samples, n_features = self._check_X(X)
112+
def _transform(self, X, handle_unknown='error', force_all_finite=True):
113+
X_list, n_samples, n_features = self._check_X(
114+
X, force_all_finite=force_all_finite)
104115

105116
X_int = np.zeros((n_samples, n_features), dtype=int)
106117
X_mask = np.ones((n_samples, n_features), dtype=bool)
@@ -355,8 +366,26 @@ def _compute_drop_idx(self):
355366
"of features ({}), got {}")
356367
raise ValueError(msg.format(len(self.categories_),
357368
len(self.drop)))
358-
missing_drops = [(i, val) for i, val in enumerate(self.drop)
359-
if val not in self.categories_[i]]
369+
missing_drops = []
370+
drop_indices = []
371+
for col_idx, (val, cat_list) in enumerate(zip(self.drop,
372+
self.categories_)):
373+
if not is_scalar_nan(val):
374+
drop_idx = np.where(cat_list == val)[0]
375+
if drop_idx.size: # found drop idx
376+
drop_indices.append(drop_idx[0])
377+
else:
378+
missing_drops.append((col_idx, val))
379+
continue
380+
381+
# val is nan, find nan in categories manually
382+
for cat_idx, cat in enumerate(cat_list):
383+
if is_scalar_nan(cat):
384+
drop_indices.append(cat_idx)
385+
break
386+
else: # loop did not break thus drop is missing
387+
missing_drops.append((col_idx, val))
388+
360389
if any(missing_drops):
361390
msg = ("The following categories were supposed to be "
362391
"dropped, but were not found in the training "
@@ -365,10 +394,7 @@ def _compute_drop_idx(self):
365394
["Category: {}, Feature: {}".format(c, v)
366395
for c, v in missing_drops])))
367396
raise ValueError(msg)
368-
return np.array([np.where(cat_list == val)[0][0]
369-
for (val, cat_list) in
370-
zip(self.drop, self.categories_)],
371-
dtype=object)
397+
return np.array(drop_indices, dtype=object)
372398

373399
def fit(self, X, y=None):
374400
"""
@@ -388,7 +414,8 @@ def fit(self, X, y=None):
388414
self
389415
"""
390416
self._validate_keywords()
391-
self._fit(X, handle_unknown=self.handle_unknown)
417+
self._fit(X, handle_unknown=self.handle_unknown,
418+
force_all_finite='allow-nan')
392419
self.drop_idx_ = self._compute_drop_idx()
393420
return self
394421

@@ -431,7 +458,8 @@ def transform(self, X):
431458
"""
432459
check_is_fitted(self)
433460
# validation of X happens in _check_X called by _transform
434-
X_int, X_mask = self._transform(X, handle_unknown=self.handle_unknown)
461+
X_int, X_mask = self._transform(X, handle_unknown=self.handle_unknown,
462+
force_all_finite='allow-nan')
435463

436464
n_samples, n_features = X_int.shape
437465

0 commit comments

Comments
 (0)
0