diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index bb84d25b59df9..76f7b7dbeee0c 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -392,6 +392,7 @@ partial dependence

    feature_extraction.DictVectorizer
    feature_extraction.FeatureHasher
+   feature_extraction.ColumnTransformer

 From images
 -----------

diff --git a/doc/modules/feature_extraction.rst b/doc/modules/feature_extraction.rst
index 9d15cda6f69b9..40ff6c1bb491d 100644
--- a/doc/modules/feature_extraction.rst
+++ b/doc/modules/feature_extraction.rst
@@ -100,6 +100,62 @@ of the time. So as to make the resulting data structure able to fit in
 memory the ``DictVectorizer`` class uses a ``scipy.sparse`` matrix by
 default instead of a ``numpy.ndarray``.

+.. _column_transformer:
+
+Columnar Data
+=============
+Many datasets contain features of different types, say text, floats and dates,
+where each type of feature requires separate preprocessing.
+Often it is easiest to preprocess the data before applying scikit-learn
+methods, for example using pandas.
+If the preprocessing has parameters that you want to adjust within a
+grid-search, however, it needs to happen inside a transformer. This can be
+achieved with the :class:`ColumnTransformer`. The
+:class:`ColumnTransformer` works on pandas dataframes, dictionaries, and other
+objects that implement ``__getitem__`` to select a certain column.
+
+.. note::
+    :class:`ColumnTransformer` expects a very different data format from the
+    numpy arrays usually used in scikit-learn.
+    For a numpy array ``X_array``, ``X_array[1]`` gives a single sample with
+    all of its features (``X_array[1].shape == (n_features,)``).
+    For columnar data like a dict or pandas dataframe ``X_columns``,
+    ``X_columns[1]`` is expected to give a feature called ``1`` for each
+    sample (``X_columns[1].shape == (n_samples,)``).
+
+To each column, a different transformation can be applied, such as
+preprocessing or a specific feature extraction method::
+
+    >>> X = {'city': ['London', 'London', 'Paris', 'New York'],
+    ...      'title': ["His Last Bow", "How Watson Learned the Trick",
+    ...                "A Moveable Feast", "The Great Gatsby"]}
+
+In contrast to :class:`DictVectorizer`, here the whole dataset is a dict,
+with each value having the same length ``n_samples``.
+For this data, we might want to one-hot encode the ``'city'`` column, but
+apply a bag-of-words :class:`CountVectorizer` to the ``'title'`` column; a
+:class:`CountVectorizer` with an identity ``analyzer`` serves as a one-hot
+encoder for the string-valued ``'city'`` column.
+As we might use multiple feature extraction methods on the same column, we
+give each transformer a unique name, say ``'city_category'`` and
+``'title_bow'``::
+
+    >>> from sklearn.feature_extraction import ColumnTransformer
+    >>> from sklearn.feature_extraction.text import CountVectorizer
+    >>> column_trans = ColumnTransformer(
+    ...     {'city_category': (CountVectorizer(analyzer=lambda x: [x]), 'city'),
+    ...      'title_bow': (CountVectorizer(), 'title')})
+
+    >>> column_trans.fit(X) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+    ColumnTransformer(n_jobs=1, transformer_weights=None,
+        transformers=...)
+
+    >>> column_trans.get_feature_names() == [
+    ...     'city_category__London', 'city_category__New York', 'city_category__Paris',
+    ...     'title_bow__bow', 'title_bow__feast', 'title_bow__gatsby',
+    ...     'title_bow__great', 'title_bow__his', 'title_bow__how', 'title_bow__last',
+    ...     'title_bow__learned', 'title_bow__moveable', 'title_bow__the',
+    ...     'title_bow__trick', 'title_bow__watson']
+    True
+
+    >>> column_trans.transform(X).toarray() # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+    array([[1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],
+           [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1],
+           [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
+           [0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0]]...)
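+
+The output of :class:`ColumnTransformer` is an ordinary feature matrix, so it
+can be used as the first step of a :class:`pipeline.Pipeline`. As a sketch
+(with made-up binary labels ``y`` for the four samples above)::
+
+    >>> from sklearn.pipeline import make_pipeline
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> pipe = make_pipeline(column_trans, LogisticRegression())
+    >>> y = [1, 1, 0, 0]
+    >>> pipe.fit(X, y).predict(X)    # doctest: +ELLIPSIS
+    array([...])
+
+Because :func:`utils.safe_indexing` also slices dictionaries per column, such
+a pipeline should be usable with the cross-validation helpers directly.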

 .. _feature_hashing:

diff --git a/doc/modules/pipeline.rst b/doc/modules/pipeline.rst
index 61a0e318da5b8..6b3b4f93da885 100644
--- a/doc/modules/pipeline.rst
+++ b/doc/modules/pipeline.rst
@@ -122,9 +122,11 @@ FeatureUnion: composite feature spaces
 :class:`FeatureUnion` combines several transformer objects into a new
 transformer that combines their output. A :class:`FeatureUnion` takes
 a list of transformer objects. During fitting, each of these
-is fit to the data independently. For transforming data, the
-transformers are applied in parallel, and the sample vectors they output
-are concatenated end-to-end into larger vectors.
+is fit to the data independently. It can also be used to apply different
+transformations to each field of the data, producing a homogeneous feature
+matrix from a heterogeneous data source.
+The transformers are applied in parallel, and the feature matrices they
+output are concatenated side-by-side into a larger matrix.

 :class:`FeatureUnion` serves the same purposes as :class:`Pipeline` -
 convenience and joint parameter estimation and validation.
@@ -166,4 +168,4 @@ Like pipelines, feature unions have a shorthand constructor called
 .. topic:: Examples:

     * :ref:`example_feature_stacker.py`
-    * :ref:`example_hetero_feature_union.py`
+    * :ref:`example_column_transformer.py` illustrates the use of
+      :class:`feature_extraction.ColumnTransformer` on heterogeneous data.

diff --git a/doc/whats_new.rst b/doc/whats_new.rst
index 3ab7946e4f05a..3e2bb70b43f38 100644
--- a/doc/whats_new.rst
+++ b/doc/whats_new.rst
@@ -217,6 +217,9 @@ New features
      By `Alexandre Gramfort`_, `Jan Hendrik Metzen`_, `Mathieu Blondel`_
      and `Balazs Kegl`_.

+   - The new :class:`feature_extraction.ColumnTransformer` applies separate
+     transformers to the columns of dictionaries or pandas dataframes and
+     concatenates the results. By `Andreas Müller`_.

 Enhancements
 ............

diff --git a/examples/hetero_feature_union.py b/examples/column_transformer.py
similarity index 68%
rename from examples/hetero_feature_union.py
rename to examples/column_transformer.py
index 9a2c6742f3946..912c8b538a8ea 100644
--- a/examples/hetero_feature_union.py
+++ b/examples/column_transformer.py
@@ -38,48 +38,9 @@
 from sklearn.feature_extraction import DictVectorizer
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics import classification_report
-from sklearn.pipeline import FeatureUnion
+from sklearn.feature_extraction import ColumnTransformer
 from sklearn.pipeline import Pipeline
-from sklearn.svm import SVC
-
-
-class ItemSelector(BaseEstimator, TransformerMixin):
-    """For data grouped by feature, select subset of data at a provided key.
-
-    The data is expected to be stored in a 2D data structure, where the first
-    index is over features and the second is over samples. i.e.
-
-    >> len(data[key]) == n_samples
-
-    Please note that this is the opposite convention to sklearn feature
-    matrixes (where the first index corresponds to sample).
-
-    ItemSelector only requires that the collection implement getitem
-    (data[key]). Examples include: a dict of lists, 2D numpy array, Pandas
-    DataFrame, numpy record array, etc.
-
-    >> data = {'a': [1, 5, 2, 5, 2, 8],
-               'b': [9, 4, 1, 4, 1, 3]}
-    >> ds = ItemSelector(key='a')
-    >> data['a'] == ds.transform(data)
-
-    ItemSelector is not designed to handle data grouped by sample. (e.g. a
-    list of dicts). If your data is structured this way, consider a
-    transformer along the lines of `sklearn.feature_extraction.DictVectorizer`.
-
-    Parameters
-    ----------
-    key : hashable, required
-        The key corresponding to the desired value in a mappable.
-    """
-    def __init__(self, key):
-        self.key = key
-
-    def fit(self, x, y=None):
-        return self
-
-    def transform(self, data_dict):
-        return data_dict[self.key]
+from sklearn.svm import LinearSVC


 class TextStats(BaseEstimator, TransformerMixin):
@@ -128,41 +89,34 @@ def transform(self, posts):
     ('subjectbody', SubjectBodyExtractor()),

-    # Use FeatureUnion to combine the features from subject and body
-    ('union', FeatureUnion(
-        transformer_list=[
-
-            # Pipeline for pulling features from the post's subject line
-            ('subject', Pipeline([
-                ('selector', ItemSelector(key='subject')),
-                ('tfidf', TfidfVectorizer(min_df=50)),
-            ])),
+    # Use ColumnTransformer to combine the features from subject and body
+    ('union', ColumnTransformer(
+        {
+            # Pulling features from the post's subject line
+            'subject': (TfidfVectorizer(min_df=50), 'subject'),

             # Pipeline for standard bag-of-words model for body
-            ('body_bow', Pipeline([
-                ('selector', ItemSelector(key='body')),
+            'body_bow': (Pipeline([
                 ('tfidf', TfidfVectorizer()),
                 ('best', TruncatedSVD(n_components=50)),
-            ])),
+            ]), 'body'),

             # Pipeline for pulling ad hoc features from post's body
-            ('body_stats', Pipeline([
-                ('selector', ItemSelector(key='body')),
+            'body_stats': (Pipeline([
                 ('stats', TextStats()),  # returns a list of dicts
                 ('vect', DictVectorizer()),  # list of dicts -> feature matrix
-            ])),
-
-        ],
+            ]), 'body'),
+        },

-        # weight components in FeatureUnion
+        # weight components in ColumnTransformer
         transformer_weights={
             'subject': 0.8,
             'body_bow': 0.5,
             'body_stats': 1.0,
-        },
+        }
     )),

-    # Use a SVC classifier on the combined features
-    ('svc', SVC(kernel='linear')),
+    # Use a linear SVC classifier on the combined features
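+    # (LinearSVC with dual=False solves the primal optimization problem,
+    # which is typically faster than SVC(kernel='linear') when
+    # n_samples > n_features)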
+    ('svc', LinearSVC(dual=False)),
 ])

-# limit the list of categories to make running this exmaple faster.
+# limit the list of categories to make running this example faster.

diff --git a/sklearn/feature_extraction/__init__.py b/sklearn/feature_extraction/__init__.py
index b45440444d769..95bc96aa521fb 100644
--- a/sklearn/feature_extraction/__init__.py
+++ b/sklearn/feature_extraction/__init__.py
@@ -5,9 +5,10 @@
 """

 from .dict_vectorizer import DictVectorizer
+from .heterogeneous import ColumnTransformer
 from .hashing import FeatureHasher
 from .image import img_to_graph, grid_to_graph
 from . import text

 __all__ = ['DictVectorizer', 'image', 'img_to_graph', 'grid_to_graph', 'text',
-           'FeatureHasher']
+           'FeatureHasher', 'ColumnTransformer']

diff --git a/sklearn/feature_extraction/heterogeneous.py b/sklearn/feature_extraction/heterogeneous.py
new file mode 100644
index 0000000000000..04b5e6139b191
--- /dev/null
+++ b/sklearn/feature_extraction/heterogeneous.py
@@ -0,0 +1,148 @@
+from scipy import sparse
+import numpy as np
+
+from ..base import BaseEstimator, TransformerMixin
+from ..pipeline import _fit_one_transformer, _fit_transform_one, _transform_one
+from ..externals.joblib import Parallel, delayed
+from ..externals.six import iteritems
+
+
+class ColumnTransformer(BaseEstimator, TransformerMixin):
+    """Applies transformers to columns of a dataframe / dict.
+
+    This estimator applies transformer objects to columns or fields of the
+    input, then concatenates the results. This is useful for heterogeneous or
+    columnar data, to combine several feature extraction mechanisms into a
+    single transformer.
+
+    Read more in the :ref:`User Guide <column_transformer>`.
+
+    Parameters
+    ----------
+    transformers : dict from string to (transformer, column) tuples
+        Keys are arbitrary names, values are tuples of a transformer object
+        and the name of the column it should be applied to.
+
+    n_jobs : int, optional
+        Number of jobs to run in parallel (default 1).
+
+    transformer_weights : dict, optional
+        Multiplicative weights for features per transformer.
+        Keys are transformer names, values the weights.
+
+    Examples
+    --------
+    >>> from sklearn.preprocessing import Normalizer
+    >>> union = ColumnTransformer({"norm1": (Normalizer(norm='l1'), 'subset1'), \
+                                   "norm2": (Normalizer(norm='l1'), 'subset2')})
+    >>> X = {'subset1': [[0., 1.], [2., 2.]], 'subset2': [[1., 1.], [0., 1.]]}
+    >>> union.fit_transform(X)    # doctest: +NORMALIZE_WHITESPACE
+    array([[ 0. ,  1. ,  0.5,  0.5],
+           [ 0.5,  0.5,  0. ,  1. ]])
+
+    """
+    def __init__(self, transformers, n_jobs=1, transformer_weights=None):
+        self.transformers = transformers
+        self.n_jobs = n_jobs
+        self.transformer_weights = transformer_weights
+
+    def get_feature_names(self):
+        """Get feature names from all transformers.
+
+        Returns
+        -------
+        feature_names : list of strings
+            Names of the features produced by transform.
+        """
+        feature_names = []
+        for name, (trans, column) in sorted(self.transformers.items()):
+            if not hasattr(trans, 'get_feature_names'):
+                raise AttributeError("Transformer %s does not provide"
+                                     " get_feature_names." % str(name))
+            feature_names.extend([name + "__" + f for f in
+                                  trans.get_feature_names()])
+        return feature_names
+
+    def get_params(self, deep=True):
+        if not deep:
+            return super(ColumnTransformer, self).get_params(deep=False)
+        else:
+            out = dict(self.transformers)
+            for name, (trans, _) in self.transformers.items():
+                for key, value in iteritems(trans.get_params(deep=True)):
+                    out['%s__%s' % (name, key)] = value
+            out.update(super(ColumnTransformer, self).get_params(deep=False))
+            return out
+
+    def fit(self, X, y=None):
+        """Fit all transformers using X.
+
+        Parameters
+        ----------
+        X : dict of {string : array-like} or pandas DataFrame
+            Columnar input data, from which each transformer's column is
+            selected to fit it.
+        """
+        transformers = Parallel(n_jobs=self.n_jobs)(
+            delayed(_fit_one_transformer)(trans, X[column], y)
+            for name, (trans, column) in sorted(self.transformers.items()))
+        self._update_transformers(transformers)
+        return self
+
+    def fit_transform(self, X, y=None, **fit_params):
+        """Fit all transformers using X, transform the data and concatenate
+        results.
+
+        Parameters
+        ----------
+        X : dict of {string : array-like} or pandas DataFrame
+            Columnar input data to be transformed.
+
+        Returns
+        -------
+        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
+            hstack of results of transformers. sum_n_components is the
+            sum of n_components (output dimension) over transformers.
+        """
+        result = Parallel(n_jobs=self.n_jobs)(
+            delayed(_fit_transform_one)(trans, name, X[column], y,
+                                        self.transformer_weights,
+                                        **fit_params)
+            for name, (trans, column) in sorted(self.transformers.items()))
+
+        Xs, transformers = zip(*result)
+        self._update_transformers(transformers)
+        if any(sparse.issparse(f) for f in Xs):
+            Xs = sparse.hstack(Xs).tocsr()
+        else:
+            Xs = np.hstack(Xs)
+        return Xs
+
+    def transform(self, X):
+        """Transform X separately by each transformer, concatenate results.
+
+        Parameters
+        ----------
+        X : dict of {string : array-like} or pandas DataFrame
+            Columnar input data to be transformed.
+
+        Returns
+        -------
+        X_t : array-like or sparse matrix, shape (n_samples, sum_n_components)
+            hstack of results of transformers. sum_n_components is the
+            sum of n_components (output dimension) over transformers.
+        """
+        Xs = Parallel(n_jobs=self.n_jobs)(
+            delayed(_transform_one)(trans, name, X[column],
+                                    self.transformer_weights)
+            for name, (trans, column) in sorted(self.transformers.items()))
+        if any(sparse.issparse(f) for f in Xs):
+            Xs = sparse.hstack(Xs).tocsr()
+        else:
+            Xs = np.hstack(Xs)
+        return Xs
+
+    def _update_transformers(self, transformers):
+        # use a dict constructor instead of a dict comprehension for python2.6
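+        # the fitted ``transformers`` are returned by Parallel in the same
+        # order as the sorted iteration in fit / fit_transform, so zipping
+        # them against sorted(self.transformers.items()) re-attaches each
+        # fitted transformer to its name and column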
+        self.transformers.update(dict(
+            (name, (new, column))
+            for ((name, (old, column)), new)
+            in zip(sorted(self.transformers.items()), transformers)))

diff --git a/sklearn/feature_extraction/tests/test_column_transformer.py b/sklearn/feature_extraction/tests/test_column_transformer.py
new file mode 100644
index 0000000000000..c5347e73975ea
--- /dev/null
+++ b/sklearn/feature_extraction/tests/test_column_transformer.py
@@ -0,0 +1,85 @@
+import numpy as np
+import scipy.sparse as sp
+
+from sklearn.base import BaseEstimator
+from sklearn.feature_extraction import ColumnTransformer
+
+from sklearn.utils.testing import assert_array_equal, assert_equal, assert_true
+from sklearn.utils.validation import check_array
+
+
+class Trans(BaseEstimator):
+    def fit(self, X, y=None):
+        return self
+
+    def transform(self, X, y=None):
+        return check_array(X).reshape(-1, 1)
+
+
+class SparseMatrixTrans(BaseEstimator):
+    def fit(self, X, y=None):
+        return self
+
+    def transform(self, X, y=None):
+        n_samples = len(X)
+        return sp.eye(n_samples, n_samples).tocsr()
+
+
+def test_column_selection():
+    # dictionary
+    X_dict = {'first': [0, 1, 2],
+              'second': [2, 4, 6]}
+    # recarray
+    X_recarray = np.recarray((3,),
+                             dtype=[('first', np.int), ('second', np.int)])
+    X_recarray['first'] = X_dict['first']
+    X_recarray['second'] = X_dict['second']
+    Xs = [X_dict, X_recarray]
+
+    try:
+        import pandas as pd
+        X_df = pd.DataFrame(X_dict)
+        Xs.append(X_df)
+    except ImportError:
+        print("Pandas not found, not testing ColumnTransformer with"
+              " DataFrame.")
+
+    X_res_first = np.array(X_dict['first']).reshape(-1, 1)
+    X_res_second = np.array(X_dict['second']).reshape(-1, 1)
+    X_res_both = np.vstack([X_dict['first'], X_dict['second']]).T
+
+    for X in Xs:
+        first_feat = ColumnTransformer({'trans': (Trans(), 'first')})
+        second_feat = ColumnTransformer({'trans': (Trans(), 'second')})
+        both = ColumnTransformer({'trans1': (Trans(), 'first'),
+                                  'trans2': (Trans(), 'second')})
+        assert_array_equal(first_feat.fit_transform(X), X_res_first)
+        assert_array_equal(second_feat.fit_transform(X), X_res_second)
+        assert_array_equal(both.fit_transform(X), X_res_both)
+        # fit then transform
+        assert_array_equal(first_feat.fit(X).transform(X), X_res_first)
+        assert_array_equal(second_feat.fit(X).transform(X), X_res_second)
+        assert_array_equal(both.fit(X).transform(X), X_res_both)
+
+    # test with transformer_weights
+    transformer_weights = {'trans1': .1, 'trans2': 10}
+    for X in Xs:
+        both = ColumnTransformer({'trans1': (Trans(), 'first'),
+                                  'trans2': (Trans(), 'second')},
+                                 transformer_weights=transformer_weights)
+        res = np.vstack([transformer_weights['trans1'] * np.array(X['first']),
+                         transformer_weights['trans2'] * np.array(X['second'])]).T
+        assert_array_equal(both.fit_transform(X), res)
+        # fit then transform
+        assert_array_equal(both.fit(X).transform(X), res)
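+
+
+def test_inconsistent_column_lengths():
+    # a sketch exercising the Mapping branch added to
+    # sklearn.utils.validation._num_samples: columnar data whose columns
+    # disagree on the number of samples should be rejected
+    from sklearn.utils.testing import assert_raises
+    from sklearn.utils.validation import _num_samples
+    assert_raises(ValueError, _num_samples,
+                  {'first': [0, 1, 2], 'second': [2, 4]})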
+
+
+def test_sparse_stacking():
+    X_dict = {'first': [0, 1, 2],
+              'second': [2, 4, 6]}
+    col_trans = ColumnTransformer({'trans1': (Trans(), 'first'),
+                                   'trans2': (SparseMatrixTrans(), 'second')})
+    col_trans.fit(X_dict)
+    X_trans = col_trans.transform(X_dict)
+    assert_true(sp.issparse(X_trans))
+    assert_equal(X_trans.shape, (X_trans.shape[0], X_trans.shape[0] + 1))
+    assert_array_equal(X_trans.toarray()[:, 1:], np.eye(X_trans.shape[0]))

diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py
index 85fedd926e947..94dcd2803f192 100644
--- a/sklearn/pipeline.py
+++ b/sklearn/pipeline.py
@@ -425,17 +425,26 @@ class FeatureUnion(BaseEstimator, TransformerMixin):

     Parameters
     ----------
-    transformer_list: list of (string, transformer) tuples
+    transformer_list : list of (string, transformer) tuples
         List of transformer objects to be applied to the data. The first
         half of each tuple is the name of the transformer.

-    n_jobs: int, optional
+    n_jobs : int, optional
         Number of jobs to run in parallel (default 1).

-    transformer_weights: dict, optional
+    transformer_weights : dict, optional
         Multiplicative weights for features per transformer.
         Keys are transformer names, values the weights.

+    Examples
+    --------
+    >>> from sklearn.decomposition import PCA, TruncatedSVD
+    >>> union = FeatureUnion([("pca", PCA()), \
+                              ("svd", TruncatedSVD())])
+    >>> X = [[0., 1., 3], [2., 2., 5]]
+    >>> union.fit_transform(X)    # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
+    array([[-1.5 ,  0. ,  3.0..., -0.8...],
+           [ 1.5 ,  0. ,  5.7...,  0.4...]])
     """
     def __init__(self, transformer_list, n_jobs=1, transformer_weights=None):
         self.transformer_list = transformer_list

diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py
index b1075e73f0791..4e96c3ed8efa1 100644
--- a/sklearn/tests/test_pipeline.py
+++ b/sklearn/tests/test_pipeline.py
@@ -89,8 +89,7 @@ def test_pipeline_init():
     pipe = Pipeline([('svc', clf)])
     assert_equal(pipe.get_params(deep=True),
                  dict(svc__a=None, svc__b=None, svc=clf,
-                      **pipe.get_params(deep=False)
-                      ))
+                      **pipe.get_params(deep=False)))

     # Check that params are set
     pipe.set_params(svc__a=0.1)
@@ -123,13 +122,13 @@ def test_pipeline_init():
     # Check that apart from estimators, the parameters are the same
     params = pipe.get_params(deep=True)
     params2 = pipe2.get_params(deep=True)
-    
+
     for x in pipe.get_params(deep=False):
         params.pop(x)
-    
+
     for x in pipe2.get_params(deep=False):
         params2.pop(x)
-    
+
-    # Remove estimators that where copied
+    # Remove estimators that were copied
     params.pop('svc')
     params.pop('anova')

diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index bcb1773399b5d..22e53fd1fd25d 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -1,11 +1,11 @@
 """
 The :mod:`sklearn.utils` module includes various utilities.
 """
-from collections import Sequence
+import warnings
+from collections import Sequence, Mapping

 import numpy as np
 from scipy.sparse import issparse
-import warnings

 from .murmurhash import murmurhash3_32
 from .validation import (as_float_array,
@@ -146,7 +146,14 @@ def safe_indexing(X, indices):
     indices : array-like, list
         Indices according to which X will be subsampled.
     """
-    if hasattr(X, "iloc"):
+    if X is None:
+        # fall-through
+        return None
+    elif isinstance(X, Mapping):
+        # slice per value
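+        # (for example, {'a': [1, 2, 3], 'b': [4, 5, 6]} indexed with [0, 2]
+        # becomes {'a': [1, 3], 'b': [4, 6]}, so the columns of columnar
+        # data stay aligned across cross-validation splits)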
""" - if hasattr(X, "iloc"): + if X is None: + # fall-through + return None + elif isinstance(X, Mapping): + # slice per value + return dict([(k, safe_indexing(v, indices)) for k, v in + X.items()]) + elif hasattr(X, "iloc"): # Pandas Dataframes and Series try: return X.iloc[indices] diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 512e4fcace7a2..8ca88b86b5dde 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -1327,6 +1327,9 @@ def fit(self, X, y): if name in ('FeatureUnion', 'Pipeline'): e = estimator([('clf', T())]) + elif name == 'ColumnTransformer': + e = estimator({'clf': (T(), 'some_column')}) + elif name in ('GridSearchCV' 'RandomizedSearchCV'): return diff --git a/sklearn/utils/testing.py b/sklearn/utils/testing.py index 157243121bfe5..cc035ce4cbaaf 100644 --- a/sklearn/utils/testing.py +++ b/sklearn/utils/testing.py @@ -517,7 +517,7 @@ def uninstall_mldata_mock(): "RFECV", "BaseEnsemble"] # estimators that there is no way to default-construct sensibly OTHER = ["Pipeline", "FeatureUnion", "GridSearchCV", - "RandomizedSearchCV"] + "RandomizedSearchCV", "ColumnTransformer"] # some trange ones DONT_TEST = ['SparseCoder', 'EllipticEnvelope', 'DictVectorizer', diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index dc8d938a8d077..2df925a3562b3 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -8,6 +8,7 @@ # License: BSD 3 clause import warnings import numbers +from collections import Mapping import numpy as np import scipy.sparse as sp @@ -110,6 +111,15 @@ def _num_samples(x): # Don't get num_samples from an ensembles length! raise TypeError('Expected sequence or array-like, got ' 'estimator %s' % x) + + if isinstance(x, Mapping) and not sp.issparse(x): + n_samples = [_num_samples(xx) for xx in x.values()] + unique_samples = np.unique(n_samples) + if len(unique_samples) > 1: + raise ValueError("Inconsistent number of samples in dictionary: %s" + % (unique_samples)) + return n_samples[0] + if not hasattr(x, '__len__') and not hasattr(x, 'shape'): if hasattr(x, '__array__'): x = np.asarray(x) @@ -193,6 +203,7 @@ def indexable(*iterables): if sp.issparse(X): result.append(X.tocsr()) elif hasattr(X, "__getitem__") or hasattr(X, "iloc"): + # np.array, tuple, list, dict or pandas data frame result.append(X) elif X is None: result.append(X)