8000 Address @jnothman's comments. · scikit-learn/scikit-learn@e11873f · GitHub
[go: up one dir, main page]

Skip to content

Commit e11873f

Browse files
committed
Address @jnothman's comments.
1 parent 85e4a5f commit e11873f

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

sklearn/preprocessing.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
'LabelEncoder',
4040
'MinMaxScaler',
4141
'Normalizer',
42+
'OneHotEncoder',
4243
'StandardScaler',
4344
'binarize',
4445
'normalize',
@@ -635,7 +636,7 @@ def transform(self, X, y=None, copy=None):
635636
return binarize(X, threshold=self.threshold, copy=copy)
636637

637638

638-
def _transform_selected(X, transform, selected):
639+
def _transform_selected(X, transform, selected="all"):
639640
"""Apply a transform function to portion of selected features
640641
641642
Parameters
@@ -653,10 +654,10 @@ def _transform_selected(X, transform, selected):
653654
-------
654655
X : array or sparse matrix, shape=(n_samples, n_features_new)
655656
"""
656-
if len(selected) == 0:
657-
return X
658-
elif selected == "all":
657+
if selected == "all":
659658
return transform(X)
659+
elif len(selected) == 0:
660+
return X
660661
else:
661662
X = atleast2d_or_csc(X)
662663
n_features = X.shape[1]

sklearn/tests/test_preprocessing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ def test_transform_selected():
629629
Xexpected = [[1, 1, 1], [0, 1, 1]]
630630
_check_transform_selected(X, Xexpected, [0, 1, 2])
631631
_check_transform_selected(X, Xexpected, [True, True, True])
632+
_check_transform_selected(X, Xexpected, "all")
632633

633634
_check_transform_selected(X, X, [])
634635
_check_transform_selected(X, X, [False, False, False])
@@ -656,6 +657,7 @@ def _check_one_hot(X, X2, cat, n_features):
656657
assert_array_equal(toarray(A), toarray(C))
657658
assert_array_equal(toarray(B), toarray(D))
658659

660+
659661
def test_one_hot_encoder_categorical_features():
660662
X = np.array([[3, 2, 1], [0, 1, 1]])
661663
X2 = np.array([[1, 1, 1]])

0 commit comments

Comments
 (0)
0