-
-
Notifications
You must be signed in to change notification settings - Fork 26.1k
FEA Add make_column_selector for ColumnTransformer #12371
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
3c2f8f4
ff2c7e7
cbd1169
dca0726
fac1f55
ae8ed2f
90e3bde
ff93079
71b5bf5
56c68e9
4ea1f79
81a6ccb
149b96a
4b69b61
2863fab
59f281e
d535b66
7b56e46
5893a78
3f9ce99
3fa1f6d
f4feaea
d970fd2
4d651c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,19 +2,23 @@ | |
Test the ColumnTransformer. | ||
""" | ||
import re | ||
import pickle | ||
|
||
import warnings | ||
import numpy as np | ||
from scipy import sparse | ||
import pytest | ||
|
||
from numpy.testing import assert_allclose | ||
from sklearn.utils._testing import assert_raise_message | ||
from sklearn.utils._testing import assert_array_equal | ||
from sklearn.utils._testing import assert_allclose_dense_sparse | ||
from sklearn.utils._testing import assert_almost_equal | ||
|
||
from sklearn.base import BaseEstimator | ||
from sklearn.compose import ColumnTransformer, make_column_transformer | ||
from sklearn.compose import ( | ||
ColumnTransformer, make_column_transformer, make_column_selector | ||
) | ||
from sklearn.exceptions import NotFittedError | ||
from sklearn.preprocessing import FunctionTransformer | ||
from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder | ||
|
@@ -1180,3 +1184,85 @@ def test_column_transformer_mask_indexing(array_type): | |
) | ||
X_trans = column_transformer.fit_transform(X) | ||
assert X_trans.shape == (3, 2) | ||
|
||
|
||
@pytest.mark.parametrize('cols, pattern, include, exclude', [
    (['col_int', 'col_float'], None, np.number, None),
    (['col_int', 'col_float'], None, None, object),
    (['col_int', 'col_float'], None, [int, float], None),
    (['col_str'], None, [object], None),
    (['col_str'], None, object, None),
    (['col_float'], None, float, None),
    (['col_float'], 'at$', [np.number], None),
    (['col_int'], None, [int], None),
    (['col_int'], '^col_int', [np.number], None),
    (['col_float', 'col_str'], 'float|str', None, None),
    (['col_str'], '^col_s', None, [int]),
    # 'str$' only matches col_str, which is not a float column, so nothing
    # is expected to be selected.
    ([], 'str$', float, None),
    (['col_int', 'col_float', 'col_str'], None, [np.number, object], None),
])
def test_make_column_selector_with_select_dtypes(cols, pattern, include,
                                                 exclude):
    # Check which columns make_column_selector picks for several
    # combinations of name pattern and dtype include/exclude filters.
    # The builtin int/float/object are used instead of the np.int/np.float/
    # np.object aliases, which were plain aliases of those builtins and have
    # been removed from recent NumPy releases.
    pd = pytest.importorskip('pandas')

    X_df = pd.DataFrame({
        'col_int': np.array([0, 1, 2], dtype=int),
        'col_float': np.array([0.0, 1.0, 2.0], dtype=float),
        'col_str': ["one", "two", "three"],
    }, columns=['col_int', 'col_float', 'col_str'])

    selector = make_column_selector(
        dtype_include=include, dtype_exclude=exclude, pattern=pattern)

    assert_array_equal(selector(X_df), cols)
|
||
|
||
def test_column_transformer_with_make_column_selector():
    # Functional test for column transformer + column selector: a
    # transformer built from selectors must produce the same output as one
    # built from explicit column name lists.
    pd = pytest.importorskip('pandas')
    # Builtin int/float replace the removed np.int/np.float aliases.
    X_df = pd.DataFrame({
        'col_int': np.array([0, 1, 2], dtype=int),
        'col_float': np.array([0.0, 1.0, 2.0], dtype=float),
        'col_cat': ["one", "two", "one"],
        'col_str': ["low", "middle", "high"]
    }, columns=['col_int', 'col_float', 'col_cat', 'col_str'])
    X_df['col_str'] = X_df['col_str'].astype('category')

    cat_selector = make_column_selector(dtype_include=['category', object])
    num_selector = make_column_selector(dtype_include=np.number)

    ohe = OneHotEncoder()
    scaler = StandardScaler()

    ct_selector = make_column_transformer((ohe, cat_selector),
                                          (scaler, num_selector))
    # List the numeric columns in the frame's own column order so the
    # comparison below is column-aligned rather than relying on col_int and
    # col_float holding identical values.
    ct_direct = make_column_transformer((ohe, ['col_cat', 'col_str']),
                                        (scaler, ['col_int', 'col_float']))

    X_selector = ct_selector.fit_transform(X_df)
    X_direct = ct_direct.fit_transform(X_df)

    assert_allclose(X_selector, X_direct)
|
||
|
||
def test_make_column_selector_error():
    # A selector called on a plain NumPy array instead of a pandas DataFrame
    # is expected to fail with a ValueError carrying this message.
    err_msg = "make_column_selector can only be applied to pandas dataframes"
    not_a_frame = np.array([[0.1, 0.2]])
    numeric_selector = make_column_selector(dtype_include=np.number)
    with pytest.raises(ValueError, match=err_msg):
        numeric_selector(not_a_frame)
|
||
|
||
def test_make_column_selector_pickle():
    # A selector restored from a pickle round trip must select the same
    # columns as the original selector.
    pd = pytest.importorskip('pandas')

    # Builtin int/float replace the removed np.int/np.float aliases.
    X_df = pd.DataFrame({
        'col_int': np.array([0, 1, 2], dtype=int),
        'col_float': np.array([0.0, 1.0, 2.0], dtype=float),
        'col_str': ["one", "two", "three"],
    }, columns=['col_int', 'col_float', 'col_str'])

    selector = make_column_selector(dtype_include=[object])
    selector_pickled = pickle.loads(pickle.dumps(selector))

    assert_array_equal(selector(X_df), selector_pickled(X_df))
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We need another "See Also" entry in the ColumnTransformer class docstring to link there.