E56F FIX TransformerMixin does not override index if transform=pandas (#25… · thomasjpfan/scikit-learn@cc8228e · GitHub
[go: up one dir, main page]

Skip to content

Commit cc8228e

Browse files
authored
FIX TransformerMixin does not override index if transform=pandas (scikit-learn#25747)
1 parent 7cc5fae commit cc8228e

File tree

3 files changed

+41
-4
lines changed

3 files changed

+41
-4
lines changed

doc/whats_new/v1.2.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,13 @@ Version 1.2.2
1212
Changelog
1313
---------
1414

15+
:mod:`sklearn.base`
16+
...................
17+
18+
- |Fix| When `set_output(transform="pandas")`, :class:`base.TransformerMixin` maintains
19+
the index if the :term:`transform` output is already a DataFrame. :pr:`25747` by
20+
`Thomas Fan`_.
21+
1522
:mod:`sklearn.calibration`
1623
..........................
1724

sklearn/utils/_set_output.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _wrap_in_pandas_container(
3434
`range(n_features)`.
3535
3636
index : array-like, default=None
37-
Index for data.
37+
Index for data. `index` is ignored if `data_to_wrap` is already a DataFrame.
3838
3939
Returns
4040
-------
@@ -55,8 +55,6 @@ def _wrap_in_pandas_container(
5555
if isinstance(data_to_wrap, pd.DataFrame):
5656
if columns is not None:
5757
data_to_wrap.columns = columns
58-
if index is not None:
59-
data_to_wrap.index = index
6058
return data_to_wrap
6159

6260
return pd.DataFrame(data_to_wrap, index=index, columns=columns)

sklearn/utils/tests/test_set_output.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index():
3333

3434
new_df = _wrap_in_pandas_container(X_df, columns=new_columns, index=new_index)
3535
assert_array_equal(new_df.columns, new_columns)
36-
assert_array_equal(new_df.index, new_index)
36+
37+
# Index does not change when the input is a DataFrame
38+
assert_array_equal(new_df.index, X_df.index)
3739

3840

3941
def test__wrap_in_pandas_container_error_validation():
@@ -260,3 +262,33 @@ class C(A, B):
260262
pass
261263

262264
assert C().transform(None) == "B"
265+
266+
267+
class EstimatorWithSetOutputIndex(_SetOutputMixin):
268+
def fit(self, X, y=None):
269+
self.n_features_in_ = X.shape[1]
270+
return self
271+
272+
def transform(self, X, y=None):
273+
import pandas as pd
274+
275+
# transform by giving output a new index.
276+
return pd.DataFrame(X.to_numpy(), index=[f"s{i}" for i in range(X.shape[0])])
277+
278+
def get_feature_names_out(self, input_features=None):
279+
return np.asarray([f"X{i}" for i in range(self.n_features_in_)], dtype=object)
280+
281+
282+
def test_set_output_pandas_keep_index():
283+
"""Check that set_output does not override index.
284+
285+
Non-regression test for gh-25730.
286+
"""
287+
pd = pytest.importorskip("pandas")
288+
289+
X = pd.DataFrame([[1, 2, 3], [4, 5, 6]], index=[0, 1])
290+
est = EstimatorWithSetOutputIndex().set_output(transform="pandas")
291+
est.fit(X)
292+
293+
X_trans = est.transform(X)
294+
assert_array_equal(X_trans.index, ["s0", "s1"])

0 commit comments

Comments
 (0)
0