@@ -668,25 +668,88 @@ def test_column_transformer_get_feature_names():
668
668
ct .fit (X )
669
669
assert ct .get_feature_names () == ['col0__a' , 'col0__b' , 'col1__c' ]
670
670
671
- # passthrough transformers not supported
671
+ # drop transformer
672
+ ct = ColumnTransformer (
673
+ [('col0' , DictVectorizer (), 0 ), ('col1' , 'drop' , 1 )])
674
+ ct .fit (X )
675
+ assert ct .get_feature_names () == ['col0__a' , 'col0__b' ]
676
+
677
+ # passthrough transformer
672
678
ct = ColumnTransformer ([('trans' , 'passthrough' , [0 , 1 ])])
673
679
ct .fit (X )
674
- assert_raise_message (
675
- NotImplementedError , 'get_feature_names is not yet supported' ,
676
- ct .get_feature_names )
680
+ assert ct .get_feature_names () == ['x0' , 'x1' ]
677
681
678
682
ct = ColumnTransformer ([('trans' , DictVectorizer (), 0 )],
679
683
remainder = 'passthrough' )
680
684
ct .fit (X )
681
- assert_raise_message (
682
- NotImplementedError , 'get_feature_names is not yet supported' ,
683
- ct .get_feature_names )
685
+ assert ct .get_feature_names () == ['trans__a' , 'trans__b' , 'x1' ]
684
686
685
- # drop transformer
686
- ct = ColumnTransformer (
687
- [('col0' , DictVectorizer (), 0 ), ('col1' , 'drop' , 1 )])
687
+ ct = ColumnTransformer ([('trans' , 'passthrough' , [1 ])],
688
+ remainder = 'passthrough' )
688
689
ct .fit (X )
689
- assert ct .get_feature_names () == ['col0__a' , 'col0__b' ]
690
+ assert ct .get_feature_names () == ['x1' , 'x0' ]
691
+
692
+ ct = ColumnTransformer ([('trans' , 'passthrough' , lambda x : [1 ])],
693
+ remainder = 'passthrough' )
694
+ ct .fit (X )
695
+ assert ct .get_feature_names () == ['x1' , 'x0' ]
696
+
697
+ ct = ColumnTransformer ([('trans' , 'passthrough' , np .array ([False , True ]))],
698
+ remainder = 'passthrough' )
699
+ ct .fit (X )
700
+ assert ct .get_feature_names () == ['x1' , 'x0' ]
701
+
702
+ ct = ColumnTransformer ([('trans' , 'passthrough' , slice (1 , 2 ))],
703
+ remainder = 'passthrough' )
704
+ ct .fit (X )
705
+ assert ct .get_feature_names () == ['x1' , 'x0' ]
706
+
707
+
708
def test_column_transformer_get_feature_names_dataframe():
    """Check get_feature_names for passthrough columns of a DataFrame.

    The expected names of passed-through columns are the DataFrame column
    labels ('col0', 'col1'), however the columns were selected.
    """
    pd = pytest.importorskip('pandas')

    # Each cell holds a dict so that DictVectorizer can consume a column.
    dict_rows = [[{'a': 1, 'b': 2}, {'a': 3, 'b': 4}],
                 [{'c': 5}, {'c': 6}]]
    X_df = pd.DataFrame(np.array(dict_rows, dtype=object).T,
                        columns=['col0', 'col1'])

    def names_after_fit(transformers, **ct_kwargs):
        # Build a fresh ColumnTransformer, fit it on the frame, and
        # report the feature names it exposes.
        ct = ColumnTransformer(transformers, **ct_kwargs)
        ct.fit(X_df)
        return ct.get_feature_names()

    # Explicit passthrough of both columns, selected by label or position.
    for both_columns in (['col0', 'col1'], [0, 1]):
        assert names_after_fit(
            [('trans', 'passthrough', both_columns)]) == ['col0', 'col1']

    # A real transformer on the first column plus a passthrough remainder.
    assert names_after_fit(
        [('col0', DictVectorizer(), 0)],
        remainder='passthrough') == ['col0__a', 'col0__b', 'col1']

    # Different ways of selecting only the second column; the remaining
    # column is expected to be listed after it.
    second_column_selectors = [
        ['col1'],
        lambda x: x[['col1']].columns,
        np.array([False, True]),
        slice(1, 2),
        [1],
    ]
    for selector in second_column_selectors:
        assert names_after_fit(
            [('trans', 'passthrough', selector)],
            remainder='passthrough') == ['col1', 'col0']
690
753
691
754
692
755
def test_column_transformer_special_strings ():
0 commit comments