@@ -27,7 +27,7 @@ class _BaseEncoder(TransformerMixin, BaseEstimator):
27
27
28
28
"""
29
29
30
- def _check_X (self , X ):
30
+ def _check_X (self , X , force_all_finite = True ):
31
31
"""
32
32
Perform custom check_array:
33
33
- convert list of strings to object dtype
@@ -41,17 +41,19 @@ def _check_X(self, X):
41
41
"""
42
42
if not (hasattr (X , 'iloc' ) and getattr (X , 'ndim' , 0 ) == 2 ):
43
43
# if not a dataframe, do normal check_array validation
44
- X_temp = check_array (X , dtype = None )
44
+ X_temp = check_array (X , dtype = None ,
45
+ force_all_finite = force_all_finite )
45
46
if (not hasattr (X , 'dtype' )
46
47
and np .issubdtype (X_temp .dtype , np .str_ )):
47
- X = check_array (X , dtype = object )
48
+ X = check_array (X , dtype = object ,
49
+ force_all_finite = force_all_finite )
48
50
else :
49
51
X = X_temp
50
52
needs_validation = False
51
53
else :
52
54
# pandas dataframe, do validation later column by column, in order
53
55
# to keep the dtype information to be used in the encoder.
54
- needs_validation = True
56
+ needs_validation = force_all_finite
55
57
56
58
n_samples , n_features = X .shape
57
59
X_columns = []
@@ -71,8 +73,9 @@ def _get_feature(self, X, feature_idx):
71
73
# numpy arrays, sparse arrays
72
74
return X [:, feature_idx ]
73
75
74
- def _fit (self , X , handle_unknown = 'error' ):
75
- X_list , n_samples , n_features = self ._check_X (X )
76
+ def _fit (self , X , handle_unknown = 'error' , force_all_finite = True ):
77
+ X_list , n_samples , n_features = self ._check_X (
78
+ X , force_all_finite = force_all_finite )
76
79
77
80
if self .categories != 'auto' :
78
81
if len (self .categories ) != n_features :
@@ -88,9 +91,16 @@ def _fit(self, X, handle_unknown='error'):
88
91
else :
89
92
cats = np .array (self .categories [i ], dtype = Xi .dtype )
90
93
if Xi .dtype != object :
91
- if not np .all (np .sort (cats ) == cats ):
92
- raise ValueError ("Unsorted categories are not "
93
- "supported for numerical categories" )
94
+ sorted_cats = np .sort (cats )
95
+ error_msg = ("Unsorted categories are not "
96
+ "supported for numerical categories" )
97
+ # if there are nans, nan should be the last element
98
+ stop_idx = - 1 if np .isnan (sorted_cats [- 1 ]) else None
99
+ if (np .any (sorted_cats [:stop_idx ] != cats [:stop_idx ]) or
100
+ (np .isnan (sorted_cats [- 1 ]) and
101
+ not np .isnan (cats [- 1 ]))):
102
+ raise ValueError (error_msg )
103
+
94
104
if handle_unknown == 'error' :
95
105
diff = _check_unknown (Xi , cats )
96
106
if diff :
@@ -99,8 +109,9 @@ def _fit(self, X, handle_unknown='error'):
99
109
raise ValueError (msg )
100
110
self .categories_ .append (cats )
101
111
102
- def _transform (self , X , handle_unknown = 'error' ):
103
- X_list , n_samples , n_features = self ._check_X (X )
112
+ def _transform (self , X , handle_unknown = 'error' , force_all_finite = True ):
113
+ X_list , n_samples , n_features = self ._check_X (
114
+ X , force_all_finite = force_all_finite )
104
115
105
116
X_int = np .zeros ((n_samples , n_features ), dtype = int )
106
117
X_mask = np .ones ((n_samples , n_features ), dtype = bool )
@@ -355,8 +366,26 @@ def _compute_drop_idx(self):
355
366
"of features ({}), got {}" )
356
367
raise ValueError (msg .format (len (self .categories_ ),
357
368
len (self .drop )))
358
- missing_drops = [(i , val ) for i , val in enumerate (self .drop )
359
- if val not in self .categories_ [i ]]
369
+ missing_drops = []
370
+ drop_indices = []
371
+ for col_idx , (val , cat_list ) in enumerate (zip (self .drop ,
372
+ self .categories_ )):
373
+ if not is_scalar_nan (val ):
374
+ drop_idx = np .where (cat_list == val )[0 ]
375
+ if drop_idx .size : # found drop idx
376
+ drop_indices .append (drop_idx [0 ])
377
+ else :
378
+ missing_drops .append ((col_idx , val ))
379
+ continue
380
+
381
+ # val is nan, find nan in categories manually
382
+ for cat_idx , cat in enumerate (cat_list ):
383
+ if is_scalar_nan (cat ):
384
+ drop_indices .append (cat_idx )
385
+ break
386
+ else : # loop did not break thus drop is missing
387
+ missing_drops .append ((col_idx , val ))
388
+
360
389
if any (missing_drops ):
361
390
msg = ("The following categories were supposed to be "
362
391
"dropped, but were not found in the training "
@@ -365,10 +394,7 @@ def _compute_drop_idx(self):
365
394
["Category: {}, Feature: {}" .format (c , v )
366
395
for c , v in missing_drops ])))
367
396
raise ValueError (msg )
368
- return np .array ([np .where (cat_list == val )[0 ][0 ]
369
- for (val , cat_list ) in
370
- zip (self .drop , self .categories_ )],
371
- dtype = object )
397
+ return np .array (drop_indices , dtype = object )
372
398
373
399
def fit (self , X , y = None ):
374
400
"""
@@ -388,7 +414,8 @@ def fit(self, X, y=None):
388
414
self
389
415
"""
390
416
self ._validate_keywords ()
391
- self ._fit (X , handle_unknown = self .handle_unknown )
417
+ self ._fit (X , handle_unknown = self .handle_unknown ,
418
+ force_all_finite = 'allow-nan' )
392
419
self .drop_idx_ = self ._compute_drop_idx ()
393
420
return self
394
421
@@ -431,7 +458,8 @@ def transform(self, X):
431
458
"""
432
459
check_is_fitted (self )
433
460
# validation of X happens in _check_X called by _transform
434
- X_int , X_mask = self ._transform (X , handle_unknown = self .handle_unknown )
461
+ X_int , X_mask = self ._transform (X , handle_unknown = self .handle_unknown ,
462
+ force_all_finite = 'allow-nan' )
435
463
436
464
n_samples , n_features = X_int .shape
437
465
0 commit comments