8000 ENH: process DataFrames in OneHot/OrdinalEncoder without converting t… · scikit-learn/scikit-learn@b2344a4 · GitHub
[go: up one dir, main page]

Skip to content

Commit b2344a4

Browse files
maikiajorisvandenbossche
authored andcommitted
ENH: process DataFrames in OneHot/OrdinalEncoder without converting to array #12147 (#13253)
1 parent 04a5733 commit b2344a4

File tree

2 files changed

+87
-26
lines changed

2 files changed

+87
-26
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,39 +38,64 @@ def _check_X(self, X):
3838
- convert list of strings to object dtype
3939
- check for missing values for object dtype data (check_array does
4040
not do that)
41+
- return list of features (arrays): this list of features is
42+
constructed feature by feature to preserve the data types
43+
of pandas DataFrame columns, as otherwise information is lost
44+
and cannot be used, eg for the `categories_` attribute.
4145
4246
"""
43-
X_temp = check_array(X, dtype=None)
44-
if not hasattr(X, 'dtype') and np.issubdtype(X_temp.dtype, np.str_):
45-
X = check_array(X, dtype=np.object)
47+
if not (hasattr(X, 'iloc') and getattr(X, 'ndim', 0) == 2):
48+
# if not a dataframe, do normal check_array validation
49+
X_temp = check_array(X, dtype=None)
50+
if (not hasattr(X, 'dtype')
51+
and np.issubdtype(X_temp.dtype, np.str_)):
52+
X = check_array(X, dtype=np.object)
53+
else:
54+
X = X_temp
55+
needs_validation = False
4656
else:
47-
X = X_temp
57+
# pandas dataframe, do validation later column by column, in order
58+
# to keep the dtype information to be used in the encoder.
59+
needs_validation = True
4860

49-
return X
61+
n_samples, n_features = X.shape
62+
X_columns = []
5063

51-
def _fit(self, X, handle_unknown='error'):
52-
X = self._check_X(X)
64+
for i in range(n_features):
65+
Xi = self._get_feature(X, feature_idx=i)
66+
Xi = check_array(Xi, ensure_2d=False, dtype=None,
67+
force_all_finite=needs_validation)
68+
X_columns.append(Xi)
5369

54-
n_samples, n_features = X.shape
70+
return X_columns, n_samples, n_features
71+
72+
def _get_feature(self, X, feature_idx):
73+
if hasattr(X, 'iloc'):
74+
# pandas dataframes
75+
return X.iloc[:, feature_idx]
76+
# numpy arrays, sparse arrays
77+
return X[:, feature_idx]
78+
79+
def _fit(self, X, handle_unknown='error'):
80+
X_list, n_samples, n_features = self._check_X(X)
5581

5682
if self._categories != 'auto':
57-
if X.dtype != object:
58-
for cats in self._categories:
59-
if not np.all(np.sort(cats) == np.array(cats)):
60-
raise ValueError("Unsorted categories are not "
61-
"supported for numerical categories")
6283
if len(self._categories) != n_features:
6384
raise ValueError("Shape mismatch: if n_values is an array,"
6485
" it has to be of shape (n_features,).")
6586

6687
self.categories_ = []
6788

6889
for i in range(n_features):
69-
Xi = X[:, i]
90+
Xi = X_list[i]
7091
if self._categories == 'auto':
7192
cats = _encode(Xi)
7293
else:
73-
cats = np.array(self._categories[i], dtype=X.dtype)
94+
cats = np.array(self._categories[i], dtype=Xi.dtype)
95+
if Xi.dtype != object:
96+
if not np.all(np.sort(cats) == cats):
97+
raise ValueError("Unsorted categories are not "
98+
"supported for numerical categories")
7499
if handle_unknown == 'error':
75100
diff = _encode_check_unknown(Xi, cats)
76101
if diff:
@@ -80,14 +105,13 @@ def _fit(self, X, handle_unknown='error'):
80105
self.categories_.append(cats)
81106

82107
def _transform(self, X, handle_unknown='error'):
83-
X = self._check_X(X)
108+
X_list, n_samples, n_features = self._check_X(X)
84109

85-
_, n_features = X.shape
86-
X_int = np.zeros_like(X, dtype=np.int)
87-
X_mask = np.ones_like(X, dtype=np.bool)
110+
X_int = np.zeros((n_samples, n_features), dtype=np.int)
111+
X_mask = np.ones((n_samples, n_features), dtype=np.bool)
88112

89113
for i in range(n_features):
90-
Xi = X[:, i]
114+
Xi = X_list[i]
91115
diff, valid_mask = _encode_check_unknown(Xi, self.categories_[i],
92116
return_mask=True)
93117

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,30 @@ def test_one_hot_encoder_inverse(sparse_, drop):
431431
assert_raises_regex(ValueError, msg, enc.inverse_transform, X_tr)
432432

433433

434+
@pytest.mark.parametrize("method", ['fit', 'fit_transform'])
435+
@pytest.mark.parametrize("X", [
436+
[1, 2],
437+
np.array([3., 4.])
438+
])
439+
def test_X_is_not_1D(X, method):
440+
oh = OneHotEncoder()
441+
442+
msg = ("Expected 2D array, got 1D array instead")
443+
with pytest.raises(ValueError, match=msg):
444+
getattr(oh, method)(X)
445+
446+
447+
@pytest.mark.parametrize("method", ['fit', 'fit_transform'])
448+
def test_X_is_not_1D_pandas(method):
449+
pd = pytest.importorskip('pandas')
450+
X = pd.Series([6, 3, 4, 6])
451+
oh = OneHotEncoder()
452+
453+
msg = ("Expected 2D array, got 1D array instead")
454+
with pytest.raises(ValueError, match=msg):
455+
getattr(oh, method)(X)
456+
457+
434458
@pytest.mark.parametrize("X, cat_exp, cat_dtype", [
435459
([['abc', 55], ['def', 55]], [['abc', 'def'], [55]], np.object_),
436460
(np.array([[1, 2], [3, 2]]), [[1, 3], [2]], np.integer),
@@ -569,8 +593,14 @@ def test_one_hot_encoder_feature_names_unicode():
569593
@pytest.mark.parametrize("X", [np.array([[1, np.nan]]).T,
570594
np.array([['a', np.nan]], dtype=object).T],
571595
ids=['numeric', 'object'])
596+
@pytest.mark.parametrize("as_data_frame", [False, True],
597+
ids=['array', 'dataframe'])
572598
@pytest.mark.parametrize("handle_unknown", ['error', 'ignore'])
573-
def test_one_hot_encoder_raise_missing(X, handle_unknown):
599+
def test_one_hot_encoder_raise_missing(X, as_data_frame, handle_unknown):
600+
if as_data_frame:
601+
pd = pytest.importorskip('pandas')
602+
X = pd.DataFrame(X)
603+
574604
ohe = OneHotEncoder(categories='auto', handle_unknown=handle_unknown)
575605

576606
with pytest.raises(ValueError, match="Input contains NaN"):
@@ -579,7 +609,12 @@ def test_one_hot_encoder_raise_missing(X, handle_unknown):
579609
with pytest.raises(ValueError, match="Input contains NaN"):
580610
ohe.fit_transform(X)
581611

582-
ohe.fit(X[:1, :])
612+
if as_data_frame:
613+
X_partial = X.iloc[:1, :]
614+
else:
615+
X_partial = X[:1, :]
616+
617+
ohe.fit(X_partial)
583618

584619
with pytest.raises(ValueError, match="Input contains NaN"):
585620
ohe.transform(X)
@@ -688,16 +723,18 @@ def test_encoder_dtypes_pandas():
688723
pd = pytest.importorskip('pandas')
689724

690725
enc = OneHotEncoder(categories='auto')
691-
exp = np.array([[1., 0., 1., 0.], [0., 1., 0., 1.]], dtype='float64')
726+
exp = np.array([[1., 0., 1., 0., 1., 0.],
727+
[0., 1., 0., 1., 0., 1.]], dtype='float64')
692728

693-
X = pd.DataFrame({'A': [1, 2], 'B': [3, 4]}, dtype='int64')
729+
X = pd.DataFrame({'A': [1, 2], 'B': [3, 4], 'C': [5, 6]}, dtype='int64')
694730
enc.fit(X)
695731
assert all([enc.categories_[i].dtype == 'int64' for i in range(2)])
696732
assert_array_equal(enc.transform(X).toarray(), exp)
697733

698-
X = pd.DataFrame({'A': [1, 2], 'B': ['a', 'b']})
734+
X = pd.DataFrame({'A': [1, 2], 'B': ['a', 'b'], 'C': [3., 4.]})
735+
X_type = [int, object, float]
699736
enc.fit(X)
700-
assert all([enc.categories_[i].dtype == 'object' for i in range(2)])
737+
assert all([enc.categories_[i].dtype == X_type[i] for i in range(3)])
701738
assert_array_equal(enc.transform(X).toarray(), exp)
702739

703740

0 commit comments

Comments
 (0)
0