def get_feature_names(self):
    """Return the feature names reported by the final step.

    The call is delegated to the last ``(name, estimator)`` pair in
    ``self.steps``; earlier steps are not consulted.

    Returns
    -------
    feature_names : list of strings
        Whatever the final step's ``get_feature_names`` returns.

    Raises
    ------
    AttributeError
        If the final step has no ``get_feature_names`` attribute.
    """
    last_name, last_step = self.steps[-1]
    if hasattr(last_step, 'get_feature_names'):
        return last_step.get_feature_names()
    # str() keeps the %-formatting safe even if the step name is a tuple.
    raise AttributeError("Transformer %s does not provide"
                         " get_feature_names." % str(last_name))
def test_feature_union_pipeline_feature_names():
    # A Pipeline nested inside a FeatureUnion must report the feature names
    # of its last step, prefixed with the name it was given in the union.
    raw_docs = [
        {'vendor': 'JunkyPizza', 'available': False,
         'text': 'the pizza burger'},
        {'vendor': 'FunkyPizza', 'available': True,
         'text': 'the coke burger'},
    ]

    class DocsPrepareTransformer(BaseEstimator):
        # Keeps the non-text attributes of each document and adds a
        # 'vendor_is_known' flag derived from KNOWN_VENDORS.
        KNOWN_VENDORS = set(['JunkyPizza'])

        def fit(self, X, y=None):
            return self

        def transform(self, X, y=None):
            prepared = []
            for doc in X:
                prepared.append({
                    'vendor': doc['vendor'],
                    'vendor_is_known': doc['vendor'] in self.KNOWN_VENDORS,
                    'available': doc['available'],
                })
            return prepared

    def extract_text(doc):
        # CountVectorizer preprocessor: pull the raw text out of a document.
        return doc['text']

    attrs_pipeline = Pipeline([
        ('prepare', DocsPrepareTransformer()),
        ('vectorize', DictVectorizer()),
    ])
    union = FeatureUnion([
        ('text', CountVectorizer(preprocessor=extract_text)),
        ('attrs', attrs_pipeline),
    ])
    union.fit(raw_docs)

    expected_names = [
        'attrs__available',
        'attrs__vendor=FunkyPizza', 'attrs__vendor=JunkyPizza',
        'attrs__vendor_is_known',
        'text__burger', 'text__coke', 'text__pizza', 'text__the',
    ]
    assert_equal(sorted(union.get_feature_names()), expected_names)