Merge pull request #2027 from mblondel/select_categorical · scikit-learn/scikit-learn@4fe51ba · GitHub

Commit 4fe51ba

committed
Merge pull request #2027 from mblondel/select_categorical
ENH Non-categorical variables in OneHotEncoder
2 parents d420aaf + e11873f commit 4fe51ba

3 files changed: +165 -28 lines changed

doc/modules/preprocessing.rst

Lines changed: 4 additions & 3 deletions
@@ -333,15 +333,16 @@ Continuing the example above::
 
     >>> enc = preprocessing.OneHotEncoder()
     >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]])
-    OneHotEncoder(dtype=<type 'float'>, n_values='auto')
+    OneHotEncoder(categorical_features='all', dtype=<type 'float'>,
+           n_values='auto')
     >>> enc.transform([[0, 1, 3]]).toarray()
     array([[ 1.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  1.]])
 
 By default, how many values each feature can take is inferred automatically from the dataset.
-It is possible to specify this explicitly using the parameter ``n_values``.
+It is possible to specify this explicitly using the parameter ``n_values``.
 There are two genders, three possible continents and four web browsers in our
 dataset.
-Then we fit the estimator, and transform a data point.
+Then we fit the estimator, and transform a data point.
 In the result, the first two numbers encode the gender, the next set of three
 numbers the continent and the last four the web browser.
sklearn/preprocessing.py

Lines changed: 99 additions & 25 deletions
@@ -12,7 +12,11 @@
 
 from .base import BaseEstimator, TransformerMixin
 from .externals.six import string_types
-from .utils import check_arrays, array2d, atleast2d_or_csr, safe_asarray
+from .utils import check_arrays
+from .utils import array2d
+from .utils import atleast2d_or_csr
+from .utils import atleast2d_or_csc
+from .utils import safe_asarray
 from .utils import warn_if_not_float
 from .utils.fixes import unique
 
@@ -35,6 +39,7 @@
     'LabelEncoder',
     'MinMaxScaler',
     'Normalizer',
+    'OneHotEncoder',
     'StandardScaler',
     'binarize',
     'normalize',
@@ -632,6 +637,53 @@ def transform(self, X, y=None, copy=None):
         return binarize(X, threshold=self.threshold, copy=copy)
 
 
+def _transform_selected(X, transform, selected="all"):
+    """Apply a transform function to portion of selected features
+
+    Parameters
+    ----------
+    X : array-like or sparse matrix, shape=(n_samples, n_features)
+        Dense array or sparse matrix.
+
+    transform : callable
+        A callable transform(X) -> X_transformed
+
+    selected: "all" or array of indices or mask
+        Specify what features to apply the transform to.
+
+    Returns
+    -------
+    X : array or sparse matrix, shape=(n_samples, n_features_new)
+    """
+    if selected == "all":
+        return transform(X)
+    elif len(selected) == 0:
+        return X
+    else:
+        X = atleast2d_or_csc(X)
+        n_features = X.shape[1]
+        ind = np.arange(n_features)
+        sel = np.zeros(n_features, dtype=bool)
+        sel[np.array(selected)] = True
+        not_sel = np.logical_not(sel)
+        n_selected = np.sum(sel)
+
+        if n_selected == 0:
+            # No features selected.
+            return X
+        elif n_selected == n_features:
+            # All features selected.
+            return transform(X)
+        else:
+            X_sel = transform(X[:, ind[sel]])
+            X_not_sel = X[:, ind[not_sel]]
+
+            if sp.issparse(X_sel) or sp.issparse(X_not_sel):
+                return sp.hstack((X_sel, X_not_sel))
+            else:
+                return np.hstack((X_sel, X_not_sel))
+
+
 class OneHotEncoder(BaseEstimator, TransformerMixin):
     """Encode categorical integer features using a one-hot aka one-of-K scheme.
 
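
The new helper applies the given transform only to the selected columns and stacks the untouched columns to the right. For illustration only (not from the commit), a NumPy-only sketch of that behaviour on a dense array, mirroring the expectation in the new tests below:

    import numpy as np

    X = np.array([[3., 2., 1.],
                  [0., 1., 1.]])
    selected = [0]                                # transform column 0 only
    binarize = lambda A: (A > 0).astype(float)    # stand-in transform(X) callable

    sel = np.zeros(X.shape[1], dtype=bool)
    sel[np.array(selected)] = True
    # transformed selected columns first, untouched columns stacked to the right
    out = np.hstack((binarize(X[:, sel]), X[:, ~sel]))
    assert np.array_equal(out, [[1., 2., 1.], [0., 1., 1.]])

The rest of the diff continues below.
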
@@ -646,11 +698,21 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
 
     Parameters
     ----------
-    n_values : 'auto', int or array of int
+    n_values : 'auto', int or array of ints
        Number of values per feature.
-        'auto' : determine value range from training data.
-        int : maximum value for all features.
-        array : maximum value per feature.
+
+        - 'auto' : determine value range from training data.
+        - int : maximum value for all features.
+        - array : maximum value per feature.
+
+    categorical_features: "all" or array of indices or mask
+        Specify what features are treated as categorical.
+
+        - 'all' (default): All features are treated as categorical.
+        - array of indices: Array of categorical feature indices.
+        - mask: Array of length n_features and with dtype=bool.
+
+        Non-categorical features are always stacked to the right of the matrix.
 
     dtype : number type, default=np.float
         Desired dtype of output.
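
To make the new parameter concrete, a hedged usage sketch (not from the diff itself; the expected column count matches the new tests added further down, under the assumption that n_values='auto' keeps only values seen during fit):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    X = np.array([[3, 2, 1],
                  [0, 1, 1]])
    enc = OneHotEncoder(categorical_features=[0])               # indices form
    out = enc.fit_transform(X).toarray()
    # Column 0 took the values {0, 3}, so it expands to two active one-hot
    # columns; the two non-categorical columns are stacked to the right of
    # them, giving shape (2, 4), the same count the new tests expect.
    assert out.shape == (2, 4)

    enc = OneHotEncoder(categorical_features=[True, False, False])   # mask form
    assert enc.fit_transform(X).toarray().shape == (2, 4)
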
@@ -680,7 +742,8 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
     >>> enc = OneHotEncoder()
     >>> enc.fit([[0, 0, 3], [1, 1, 0], [0, 2, 1], \
 [1, 0, 2]])  # doctest: +ELLIPSIS
-    OneHotEncoder(dtype=<... 'float'>, n_values='auto')
+    OneHotEncoder(categorical_features='all', dtype=<type 'float'>,
+           n_values='auto')
     >>> enc.n_values_
     array([2, 3, 4])
     >>> enc.feature_indices_
@@ -690,12 +753,13 @@ class OneHotEncoder(BaseEstimator, TransformerMixin):
 
     See also
     --------
-    LabelEncoder : performs a one-hot encoding on arbitrary class labels.
     sklearn.feature_extraction.DictVectorizer : performs a one-hot encoding of
       dictionary items (also handles string-valued features).
     """
-    def __init__(self, n_values="auto", dtype=np.float):
+    def __init__(self, n_values="auto", categorical_features="all",
+                 dtype=np.float):
         self.n_values = n_values
+        self.categorical_features = categorical_features
         self.dtype = dtype
 
     def fit(self, X, y=None):
@@ -713,12 +777,8 @@ def fit(self, X, y=None):
         self.fit_transform(X)
         return self
 
-    def fit_transform(self, X, y=None):
-        """Fit OneHotEncoder to X, then transform X.
-
-        Equivalent to self.fit(X).transform(X), but more convenient and more
-        efficient. See fit for the parameters, transform for the return value.
-        """
+    def _fit_transform(self, X):
+        """Asssumes X contains only categorical features."""
         X = check_arrays(X, sparse_format='dense', dtype=np.int)[0]
         if np.any(X < 0):
             raise ValueError("X needs to contain only non-negative integers.")
759819

760820
return out
761821

762-
def transform(self, X):
763-
"""Transform X using one-hot encoding.
764-
765-
Parameters
766-
----------
767-
X : array-like, shape=(n_samples, feature_indices_[-1])
768-
Input array of type int.
822+
def fit_transform(self, X, y=None):
823+
"""Fit OneHotEncoder to X, then transform X.
769824
770-
Returns
771-
-------
772-
X_out : sparse matrix, dtype=int
773-
Transformed input.
825+
Equivalent to self.fit(X).transform(X), but more convenient and more
826+
efficient. See fit for the parameters, transform for the return value.
774827
"""
828+
return _transform_selected(X, self._fit_transform,
829+
self.categorical_features)
830+
831+
def _transform(self, X):
832+
"""Asssumes X contains only categorical features."""
775833
X = check_arrays(X, sparse_format='dense', dtype=np.int)[0]
776834
if np.any(X < 0):
777835
raise ValueError("X needs to contain only non-negative integers.")
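
The refactoring keeps the documented guarantee that fit_transform is equivalent to fit followed by transform; both now delegate the categorical part to _transform_selected. A small sanity sketch of that equivalence (an assumption-checking snippet, not part of the commit):

    import numpy as np
    from sklearn.preprocessing import OneHotEncoder

    X = np.array([[0, 0, 3], [1, 1, 0], [0, 2, 1], [1, 0, 2]])
    A = OneHotEncoder().fit_transform(X).toarray()
    B = OneHotEncoder().fit(X).transform(X).toarray()
    assert np.array_equal(A, B)   # same encoding either way
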
@@ -798,6 +856,22 @@ def transform(self, X):
             out = out[:, self.active_features_]
         return out
 
+    def transform(self, X):
+        """Transform X using one-hot encoding.
+
+        Parameters
+        ----------
+        X : array-like, shape=(n_samples, n_features)
+            Input array of type int.
+
+        Returns
+        -------
+        X_out : sparse matrix, dtype=int
+            Transformed input.
+        """
+        return _transform_selected(X, self._transform,
+                                   self.categorical_features)
+
 
 class LabelEncoder(BaseEstimator, TransformerMixin):
     """Encode labels with value between 0 and n_classes-1.

sklearn/tests/test_preprocessing.py

Lines changed: 62 additions & 0 deletions
@@ -15,6 +15,7 @@
 from sklearn.preprocessing import Binarizer
 from sklearn.preprocessing import KernelCenterer
 from sklearn.preprocessing import LabelBinarizer
+from sklearn.preprocessing import _transform_selected
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.preprocessing import LabelEncoder
 from sklearn.preprocessing import Normalizer
@@ -612,6 +613,67 @@ def test_one_hot_encoder():
     assert_raises(ValueError, enc.transform, [[0], [-1]])
 
 
+def _check_transform_selected(X, Xexpected, sel):
+    for M in (X, sp.csr_matrix(X)):
+        Xtr = _transform_selected(M, Binarizer().transform, sel)
+        assert_array_equal(toarray(Xtr), Xexpected)
+
+
+def test_transform_selected():
+    X = [[3, 2, 1], [0, 1, 1]]
+
+    Xexpected = [[1, 2, 1], [0, 1, 1]]
+    _check_transform_selected(X, Xexpected, [0])
+    _check_transform_selected(X, Xexpected, [True, False, False])
+
+    Xexpected = [[1, 1, 1], [0, 1, 1]]
+    _check_transform_selected(X, Xexpected, [0, 1, 2])
+    _check_transform_selected(X, Xexpected, [True, True, True])
+    _check_transform_selected(X, Xexpected, "all")
+
+    _check_transform_selected(X, X, [])
+    _check_transform_selected(X, X, [False, False, False])
+
+
+def _run_one_hot(X, X2, cat):
+    enc = OneHotEncoder(categorical_features=cat)
+    Xtr = enc.fit_transform(X)
+    X2tr = enc.transform(X2)
+    return Xtr, X2tr
+
+
+def _check_one_hot(X, X2, cat, n_features):
+    ind = np.where(cat)[0]
+    # With mask
+    A, B = _run_one_hot(X, X2, cat)
+    # With indices
+    C, D = _run_one_hot(X, X2, ind)
+    # Check shape
+    assert_equal(A.shape, (2, n_features))
+    assert_equal(B.shape, (1, n_features))
+    assert_equal(C.shape, (2, n_features))
+    assert_equal(D.shape, (1, n_features))
+    # Check that mask and indices give the same results
+    assert_array_equal(toarray(A), toarray(C))
+    assert_array_equal(toarray(B), toarray(D))
+
+
+def test_one_hot_encoder_categorical_features():
+    X = np.array([[3, 2, 1], [0, 1, 1]])
+    X2 = np.array([[1, 1, 1]])
+
+    cat = [True, False, False]
+    _check_one_hot(X, X2, cat, 4)
+
+    # Edge case: all non-categorical
+    cat = [False, False, False]
+    _check_one_hot(X, X2, cat, 3)
+
+    # Edge case: all categorical
+    cat = [True, True, True]
+    _check_one_hot(X, X2, cat, 5)
+
+
 def test_label_encoder():
     """Test LabelEncoder's transform and inverse_transform methods"""
     le = LabelEncoder()
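
For reference, a worked check of where the expected widths 4, 3 and 5 in test_one_hot_encoder_categorical_features come from, assuming (as in this commit) that n_values='auto' keeps only the values seen during fit:

    import numpy as np

    X = np.array([[3, 2, 1], [0, 1, 1]])
    widths = [len(np.unique(col)) for col in X.T]   # one-hot width per column
    assert widths == [2, 2, 1]
    assert widths[0] + 2 == 4      # cat=[True, False, False]: 2 encoded + 2 pass-through
    assert X.shape[1] == 3         # cat=[False, False, False]: nothing is encoded
    assert sum(widths) == 5        # cat=[True, True, True]: 2 + 2 + 1
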

0 commit comments