8000 Inbetween adding the seen option · scikit-learn/scikit-learn@b5dcd0a · GitHub
[go: up one dir, main page]

Skip to content

Commit b5dcd0a

Browse files
Inbetween adding the seen option
1 parent 023d0b6 commit b5dcd0a

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

sklearn/preprocessing/data.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,8 +1744,13 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
17441744
17451745
Parameters
17461746
----------
1747-
values : 'auto', int, list of ints, or list of lists of objects
1748-
- 'auto' : determine set of values from training data.
1747+
values : 'auto', 'seen', int, list of ints, or list of lists of objects
1748+
- 'auto' : determine set of values from training data. If the input
1749+
is an int array, values are determined from range in
1750+
training data. For all other inputs, only values observed
1751+
during `fit` are considered valid values for each feature.
1752+
- 'seen': Only values observed during `fit` are considered valid
1753+
values for each feature.
17491754
- int : values are in ``range(values)`` for all features
17501755
- list of ints : values for feature ``i`` are in ``range(values[i])``
17511756
- list of lists : values for feature ``i`` are in ``values[i]``
@@ -1828,7 +1833,8 @@ def fit(self, X, y=None):
18281833
self
18291834
"""
18301835

1831-
X = check_array(X, dtype=np.object, accept_sparse='csc', copy=self.copy)
1836+
X = check_array(X, dtype=np.object, accept_sparse='csc',
1837+
copy=self.copy)
18321838
n_samples, n_features = X.shape
18331839

18341840
_apply_selected(X, self._fit, dtype=self.dtype,
@@ -1873,6 +1879,8 @@ def _fit(self, X):
18731879
for i in range(n_features):
18741880
le = self.label_encoders_[i]
18751881
if self.values == 'auto':
1882+
le.fit(np.arange(1 + np.max(X[:, i])))
1883+
elif self.values == 'seen':
18761884
le.fit(X[:, i])
18771885
elif isinstance(self.values, numbers.Integral):
18781886
if (np.max(X, axis=0) >= self.values).any():

0 commit comments

Comments
 (0)
0