8000 FIX Fixes transform wrappig in _SetOutputMixin (#25295) · Anthony22-dev/scikit-learn@a0a6ea7 · GitHub
[go: up one dir, main page]

Skip to content

Commit a0a6ea7

Browse files
authored
FIX Fixes transform wrappig in _SetOutputMixin (scikit-learn#25295)
Fixes scikit-learn#25293
1 parent 44bf2ab commit a0a6ea7

File tree

3 files changed

+30
-0
lines changed

3 files changed

+30
-0
lines changed

doc/whats_new/v1.2.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ Changelog
1919
certain estimators to be pickled when using Python 3.11. :pr:`25188` by
2020
:user:`Benjamin Bossan <BenjaminBossan>`.
2121

22+
- |Fix| Inheriting from :class:`base.TransformerMixin` will only wrap the `transform`
23+
method if the class defines `transform` itself. :pr:`25295` by `Thomas Fan`_.
24+
2225
:mod:`sklearn.linear_model`
2326
...........................
2427

sklearn/utils/_set_output.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ def __init_subclass__(cls, auto_wrap_output_keys=("transform",), **kwargs):
200200
if not hasattr(cls, method) or key not in auto_wrap_output_keys:
201201
continue
202202
cls._sklearn_auto_wrap_output_keys.add(key)
203+
204+
# Only wrap methods defined by cls itself
205+
if method not in cls.__dict__:
206+
continue
203207
wrapped_method = _wrap_method_output(getattr(cls, method), key)
204208
setattr(cls, method, wrapped_method)
205209

sklearn/utils/tests/test_set_output.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,3 +237,26 @@ def get_columns():
237237
X_np = np.asarray([[1, 3], [2, 4], [3, 5]])
238238
X_wrapped = _wrap_in_pandas_container(X_np, columns=get_columns)
239239
assert_array_equal(X_wrapped.columns, range(X_np.shape[1]))
240+
241+
242+
def test_set_output_mro():
243+
"""Check that multi-inheritance resolves to the correct class method.
244+
245+
Non-regression test gh-25293.
246+
"""
247+
248+
class Base(_SetOutputMixin):
249+
def transform(self, X):
250+
return "Base" # noqa
251+
252+
class A(Base):
253+
pass
254+
255+
class B(Base):
256+
def transform(self, X):
257+
return "B"
258+
259+
class C(A, B):
260+
pass
261+
262+
assert C().transform(None) == "B"

0 commit comments

Comments
 (0)
0