8000 FIX Keeps namedtuple's class when transform returns a tuple (#26121) · Veghit/scikit-learn@6d649ba · GitHub
[go: up one dir, main page]

Skip to content

Commit 6d649ba

Browse files
thomasjpfanItay
authored andcommitted
FIX Keeps namedtuple's class when transform returns a tuple (scikit-learn#26121)
1 parent 73e0115 commit 6d649ba

File tree

3 files changed

+30
-1
lines changed

3 files changed

+30
-1
lines changed

doc/whats_new/v1.3.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ Changelog
166166
- |Feature| A `__sklearn_clone__` protocol is now available to override the
167167
default behavior of :func:`base.clone`. :pr:`24568` by `Thomas Fan`_.
168168

169+
- |Fix| :class:`base.TransformerMixin` now currently keeps a namedtuple's class
170+
if `transform` returns a namedtuple. :pr:`26121` by `Thomas Fan`_.
171+
169172
:mod:`sklearn.calibration`
170173
..........................
171174

sklearn/utils/_set_output.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,10 +140,15 @@ def wrapped(self, X, *args, **kwargs):
140140
data_to_wrap = f(self, X, *args, **kwargs)
141141
if isinstance(data_to_wrap, tuple):
142142
# only wrap the first output for cross decomposition
143-
return (
143+
return_tuple = (
144144
_wrap_data_with_container(method, data_to_wrap[0], X, self),
145145
*data_to_wrap[1:],
146146
)
147+
# Support for namedtuples `_make` is a documented API for namedtuples:
148+
# https://docs.python.org/3/library/collections.html#collections.somenamedtuple._make
149+
if hasattr(type(data_to_wrap), "_make"):
150+
return type(data_to_wrap)._make(return_tuple)
151+
return return_tuple
147152

148153
return _wrap_data_with_container(method, data_to_wrap, X, self)
149154

sklearn/utils/tests/test_set_output.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from collections import namedtuple
23

34
import numpy as np
45
from scipy.sparse import csr_matrix
@@ -292,3 +293,23 @@ def test_set_output_pandas_keep_index():
292293

293294
X_trans = est.transform(X)
294295
assert_array_equal(X_trans.index, ["s0", "s1"])
296+
297+
298+
class EstimatorReturnTuple(_SetOutputMixin):
299+
def __init__(self, OutputTuple):
300+
self.OutputTuple = OutputTuple
301+
302+
def transform(self, X, y=None):
303+
return self.OutputTuple(X, 2 * X)
304+
305+
306+
def test_set_output_named_tuple_out():
307+
"""Check that namedtuples are kept by default."""
308+
Output = namedtuple("Output", "X, Y")
309+
X = np.asarray([[1, 2, 3]])
310+
est = EstimatorReturnTuple(OutputTuple=Output)
311+
X_trans = est.transform(X)
312+
313+
assert isinstance(X_trans, Output)
314+
assert_array_equal(X_trans.X, X)
315+
assert_array_equal(X_trans.Y, 2 * X)

0 commit comments

Comments
 (0)
0