@@ -257,18 +257,32 @@ def test_column_transformer_dataframe():
257
257
# ensure pandas object is passed through
258
258
259
259
class TransAssert (BaseEstimator ):
260
+ def __init__ (self , expected_type_transform ):
261
+ self .expected_type_transform = expected_type_transform
262
+
260
263
def fit (self , X , y = None ):
261
264
return self
262
265
263
266
def transform (self , X , y = None ):
264
- assert isinstance (X , ( pd . DataFrame , pd . Series ) )
267
+ assert isinstance (X , self . expected_type_transform )
265
268
if isinstance (X , pd .Series ):
266
269
X = X .to_frame ()
267
270
return X
268
271
269
- ct = ColumnTransformer ([("trans" , TransAssert (), "first" )], remainder = "drop" )
272
+ ct = ColumnTransformer (
273
+ [("trans" , TransAssert (expected_type_transform = pd .Series ), "first" )],
274
+ remainder = "drop" ,
275
+ )
270
276
ct .fit_transform (X_df )
271
- ct = ColumnTransformer ([("trans" , TransAssert (), ["first" , "second" ])])
277
+ ct = ColumnTransformer (
278
+ [
279
+ (
280
+ "trans" ,
281
+ TransAssert (expected_type_transform = pd .DataFrame ),
282
+ ["first" , "second" ],
283
+ )
284
+ ]
285
+ )
272
286
ct .fit_transform (X_df )
273
287
274
288
# integer column spec + integer column names -> still use positional
0 commit comments