FIX incorrect error when OneHotEncoder.transform called prior to fit … · thoo/scikit-learn@509604f · GitHub
[go: up one dir, main page]

Skip to content

Commit 509604f

Browse files
dillongardner authored and thoo committed
FIX incorrect error when OneHotEncoder.transform called prior to fit (scikit-learn#12443)
1 parent 47ee912 commit 509604f

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

sklearn/preprocessing/_encoders.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,8 @@
2121
from .base import _transform_selected
2222
from .label import _encode, _encode_check_unknown
2323

24-
2524
range = six.moves.range
2625

27-
2826
__all__ = [
2927
'OneHotEncoder',
3028
'OrdinalEncoder'
@@ -383,6 +381,12 @@ def _handle_deprecations(self, X):
383381
"The 'categorical_features' keyword is deprecated in "
384382
"version 0.20 and will be removed in 0.22. You can "
385383
"use the ColumnTransformer instead.", DeprecationWarning)
384+
# Set categories_ to empty list if no categorical columns exist
385+
n_features = X.shape[1]
386+
sel = np.zeros(n_features, dtype=bool)
387+
sel[np.asarray(self.categorical_features)] = True
388+
if sum(sel) == 0:
389+
self.categories_ = []
386390
self._legacy_mode = True
387391
self._categorical_features = self.categorical_features
388392
else:
@@ -591,6 +595,7 @@ def transform(self, X):
591595
X_out : sparse matrix if sparse=True else a 2-d array
592596
Transformed input.
593597
"""
598+
check_is_fitted(self, 'categories_')
594599
if self._legacy_mode:
595600
return _transform_selected(X, self._legacy_transform, self.dtype,
596601
self._categorical_features,
@@ -683,7 +688,7 @@ def get_feature_names(self, input_features=None):
683688
cats = self.categories_
684689
if input_features is None:
685690
input_features = ['x%d' % i for i in range(len(cats))]
686-
elif(len(input_features) != len(self.categories_)):
691+
elif len(input_features) != len(self.categories_):
687692
raise ValueError(
688693
"input_features should have length equal to number of "
689694
"features ({}), got {}".format(len(self.categories_),

sklearn/preprocessing/tests/test_encoders.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from scipy import sparse
88
import pytest
99

10+
from sklearn.exceptions import NotFittedError
1011
from sklearn.utils.testing import assert_array_equal
1112
from sklearn.utils.testing import assert_equal
1213
from sklearn.utils.testing import assert_raises
@@ -250,6 +251,28 @@ def test_one_hot_encoder_handle_unknown():
250251
assert_raises(ValueError, oh.fit, X)
251252

252253

254+
def test_one_hot_encoder_not_fitted():
255+
X = np.array([['a'], ['b']])
256+
enc = OneHotEncoder(categories=['a', 'b'])
257+
msg = ("This OneHotEncoder instance is not fitted yet. "
258+
"Call 'fit' with appropriate arguments before using this method.")
259+
with pytest.raises(NotFittedError, match=msg):
260+
enc.transform(X)
261+
262+
263+
def test_one_hot_encoder_no_categorical_features():
264+
X = np.array([[3, 2, 1], [0, 1, 1]], dtype='float64')
265+
266+
cat = [False, False, False]
267+
enc = OneHotEncoder(categorical_features=cat)
268+
with ignore_warnings(category=(DeprecationWarning, FutureWarning)):
269+
X_tr = enc.fit_transform(X)
270+
expected_features = np.array(list(), dtype='object')
271+
assert_array_equal(X, X_tr)
272+
assert_array_equal(enc.get_feature_names(), expected_features)
273+
assert enc.categories_ == []
274+
275+
253276
@pytest.mark.parametrize("output_dtype", [np.int32, np.float32, np.float64])
254277
@pytest.mark.parametrize("input_dtype", [np.int32, np.float32, np.float64])
255278
def test_one_hot_encoder_dtype(input_dtype, output_dtype):

0 commit comments

Comments
 (0)
0