MNT do not call fit twice in TransformedTargetRegressor (#11641) · scikit-learn/scikit-learn@c5075f5 · GitHub
[go: up one dir, main page]

Skip to content

Commit c5075f5

Browse files
glemaitre and qinhanmin2014
authored and committed
MNT do not call fit twice in TransformedTargetRegressor (#11641)
1 parent 62d2059 commit c5075f5

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

sklearn/compose/_target.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def __init__(self, regressor=None, transformer=None,
113113
self.check_inverse = check_inverse
114114

115115
def _fit_transformer(self, y):
116+
"""Check transformer and fit transformer.
117+
118+
Create the default transformer, fit it and make additional inverse
119+
check on a subset (optional).
120+
121+
"""
116122
if (self.transformer is not None and
117123
(self.func is not None or self.inverse_func is not None)):
118124
raise ValueError("'transformer' and functions 'func'/"
@@ -177,19 +183,20 @@ def fit(self, X, y, sample_weight=None):
177183
y_2d = y
178184
self._fit_transformer(y_2d)
179185

180-
if self.regressor is None:
181-
from ..linear_model import LinearRegression
182-
self.regressor_ = LinearRegression()
183-
else:
184-
self.regressor_ = clone(self.regressor)
185-
186186
# transform y and convert back to 1d array if needed
187-
y_trans = self.transformer_.fit_transform(y_2d)
187+
y_trans = self.transformer_.transform(y_2d)
188188
# FIXME: a FunctionTransformer can return a 1D array even when validate
189189
# is set to True. Therefore, we need to check the number of dimension
190190
# first.
191191
if y_trans.ndim == 2 and y_trans.shape[1] == 1:
192192
y_trans = y_trans.squeeze(axis=1)
193+
194+
if self.regressor is None:
195+
from ..linear_model import LinearRegression
196+
self.regressor_ = LinearRegression()
197+
else:
198+
self.regressor_ = clone(self.regressor)
199+
193200
if sample_weight is None:
194201
self.regressor_.fit(X, y_trans)
195202
else:

sklearn/compose/tests/test_target.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
265265
tt.predict(X.tolist())
266266
assert_raises(AssertionError, tt.fit, X, y.tolist())
267267
assert_raises(AssertionError, tt.predict, X)
268+
269+
270+
class DummyTransformer(BaseEstimator, TransformerMixin):
271+
"""Dummy transformer that counts how many times fit was called."""
272+
def __init__(self, fit_counter=0):
273+
self.fit_counter = fit_counter
274+
275+
def fit(self, X, y=None):
276+
self.fit_counter += 1
277+
return self
278+
279+
def transform(self, X):
280+
return X
281+
282+
def inverse_transform(self, X):
283+
return X
284+
285+
286+
@pytest.mark.parametrize("check_inverse", [False, True])
287+
def test_transform_target_regressor_count_fit(check_inverse):
288+
# regression test for gh-issue #11618
289+
# check that fit is called only once on the transformer
290+
X, y = friedman
291+
ttr = TransformedTargetRegressor(
292+
transformer=DummyTransformer(), check_inverse=check_inverse
293+
)
294+
ttr.fit(X, y)
295+
assert ttr.transformer_.fit_counter == 1

0 commit comments

Comments
 (0)
0