remove seen argument and support range case with FutureWarning · scikit-learn/scikit-learn@398070f · GitHub
[go: up one dir, main page]

Skip to content

Commit 398070f

Browse files
remove seen argument and support range case with FutureWarning
1 parent b5dcd0a commit 398070f

File tree

2 files changed

+43
-18
lines changed

2 files changed

+43
-18
lines changed

sklearn/preprocessing/data.py

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1745,12 +1745,9 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
17451745
Parameters
17461746
----------
17471747
values : 'auto', int, list of ints, or list of lists of objects
1748-
- 'auto' : determine set of values from training data. If the input
1749-
is an int array, values are determined from range in
1750-
training data. For all other inputs, only values observed
1751-
during `fit` are considered valid values for each feature.
1752-
- 'seen': Only values observed during `fit` are considered valid
1753-
values for each feature.
1748+
- 'auto' : determine set of values from training data. See the
1749+
documentation of `handle_unknown` for which values are considered
1750+
acceptable.
17541751
- int : values are in ``range(values)`` for all features
17551752
- list of ints : values for feature ``i`` are in ``range(values[i])``
17561753
- list of lists : values for feature ``i`` are in ``values[i]``
@@ -1771,8 +1768,12 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
17711768
Will return sparse matrix if set True else will return an array.
17721769
17731770
handle_unknown : str, 'error', 'error-strict' or 'ignore'
1774-
Whether to raise an error or ignore if a unknown categorical feature is
1775-
present during transform.
1771+
1772+
- 'ignore': Ignore all unknown feature values.
1773+
- 'error': Raise an error when the value of a feature is unseen during
1774+
`fit` and out of range of values seen during `fit`.
1775+
- 'error-strict': Raise an error when the value of a feature is unseen
1776+
during `fit`.
17761777
17771778
copy : bool, default=True
17781779
If unset, `X` may be modified in place.
@@ -1850,6 +1851,8 @@ def _fit(self, X):
18501851

18511852
self._n_features = n_features
18521853
self.label_encoders_ = [LabelEncoder() for i in range(n_features)]
1854+
# Maximum value for each feature
1855+
self._max_values = [None for i in range(n_features)]
18531856

18541857
if self.n_values is not None:
18551858
warnings.warn('The parameter `n_values` is deprecated, use the'
@@ -1878,9 +1881,9 @@ def _fit(self, X):
18781881

18791882
for i in range(n_features):
18801883
le = self.label_encoders_[i]
1884+
1885+
self._max_values[i] = np.max(X[:, i])
18811886
if self.values == 'auto':
1882-
le.fit(np.arange(1 + np.max(X[:, i])))
1883-
elif self.values == 'seen':
18841887
le.fit(X[:, i])
18851888
elif isinstance(self.values, numbers.Integral):
18861889
if (np.max(X, axis=0) >= self.values).any():
@@ -1931,14 +1934,27 @@ def _transform(self, X):
19311934
valid_mask = in1d(X[:, i], self.label_encoders_[i].classes_)
19321935

19331936
if not np.all(valid_mask):
1934-
1935-
if self.handle_unknown == 'error':
1937+
if self.handle_unknown in ['error', 'error-strict']:
19361938
diff = setdiff1d(X[:, i], self.label_encoders_[i].classes_)
1937-
msg = 'Unknown feature(s) %s in column %d' % (diff, i)
1938-
raise ValueError(msg)
1939+
if self.handle_unknown == 'error-strict':
1940+
msg = 'Unknown feature(s) %s in column %d' % (diff, i)
1941+
raise ValueError(msg)
1942+
else:
1943+
if np.all(diff <= self._max_values[i]):
1944+
msg = ('Values %s for feature %d are unknown but '
1945+
'in range. This will raise an error in '
1946+
'future versions.' % (str(diff), i))
1947+
warnings.warn(FutureWarning(msg))
1948+
X_mask[:, i] = valid_mask
1949+
le = self.label_encoders_[i]
1950+
X[:, i][~valid_mask] = le.classes_[0]
1951+
else:
1952+
msg = ('Unknown feature(s) %s in column %d' %
1953+
(diff, i))
1954+
raise ValueError(msg)
19391955
elif self.handle_unknown == 'ignore':
19401956
# Set the problematic rows to an acceptable value and
1941-
# continue `The rows are marked in `X_mask` and will be
1957+
# continue. The rows are marked in `X_mask` and will be
19421958
# removed later.
19431959
X_mask[:, i] = valid_mask
19441960
X[:, i][~valid_mask] = self.label_encoders_[i].classes_[0]

sklearn/preprocessing/tests/test_data.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ def test_one_hot_encoder_string():
16051605

16061606
def test_one_hot_encoder_categorical_features():
16071607
X = np.array([[3, 2, 1], [0, 1, 1]])
1608-
X2 = np.array([[3, 1, 1]])
1608+
X2 = np.array([[1, 1, 1]])
16091609

16101610
cat = [True, False, False]
16111611
_check_one_hot(X, X2, cat, 4)
@@ -1625,10 +1625,19 @@ def test_one_hot_encoder_unknown_transform():
16251625

16261626
# Test that one hot encoder raises error for unknown features
16271627
# present during transform.
1628-
oh = OneHotEncoder(handle_unknown='error')
1628+
oh = OneHotEncoder(handle_unknown='error-strict')
16291629
oh.fit(X)
16301630
assert_raises(ValueError, oh.transform, y)
16311631

1632+
# Test that one hot encoder raises warning for unknown but in range
1633+
# features
1634+
oh = OneHotEncoder(handle_unknown='error')
1635+
oh.fit(X)
1636+
msg = ('Values [0] for feature 2 are unknown but in range. '
1637+
'This will raise an error in future versions.')
1638+
assert_warns_message(FutureWarning, msg, oh.transform,
1639+
np.array([[0, 0, 0]]))
1640+
16321641
# Test the ignore option, ignores unknown features.
16331642
oh = OneHotEncoder(handle_unknown='ignore')
16341643
oh.fit(X)
@@ -1641,7 +1650,7 @@ def test_one_hot_encoder_unknown_transform():
16411650

16421651
# Test that one hot encoder raises error for unknown features
16431652
# present during transform.
1644-
oh = OneHotEncoder(handle_unknown='error')
1653+
oh = OneHotEncoder(handle_unknown='error-strict')
16451654
oh.fit(X)
16461655
assert_raises(ValueError, oh.transform, y)
16471656

0 commit comments

Comments
 (0)
0