10000 TST Be more explicit in test_column_transformer_dataframe test (#26667) · scikit-learn/scikit-learn@8c054b6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 8c054b6

Browse files
authored
TST Be more explicit in test_column_transformer_dataframe test (#26667)
1 parent a9fa9a7 commit 8c054b6

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

sklearn/compose/tests/test_column_transformer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,18 +257,32 @@ def test_column_transformer_dataframe():
257257
# ensure pandas object is passed through
258258

259259
class TransAssert(BaseEstimator):
260+
def __init__(self, expected_type_transform):
261+
self.expected_type_transform = expected_type_transform
262+
260263
def fit(self, X, y=None):
261264
return self
262265

263266
def transform(self, X, y=None):
264-
assert isinstance(X, (pd.DataFrame, pd.Series))
267+
assert isinstance(X, self.expected_type_transform)
265268
if isinstance(X, pd.Series):
266269
X = X.to_frame()
267270
return X
268271

269-
ct = ColumnTransformer([("trans", TransAssert(), "first")], remainder="drop")
272+
ct = ColumnTransformer(
273+
[("trans", TransAssert(expected_type_transform=pd.Series), "first")],
274+
remainder="drop",
275+
)
270276
ct.fit_transform(X_df)
271-
ct = ColumnTransformer([("trans", TransAssert(), ["first", "second"])])
277+
ct = ColumnTransformer(
278+
[
279+
(
280+
"trans",
281+
TransAssert(expected_type_transform=pd.DataFrame),
282+
["first", "second"],
283+
)
284+
]
285+
)
272286
ct.fit_transform(X_df)
273287

274288
# integer column spec + integer column names -> still use positional

0 commit comments

Comments
 (0)
0