BUG: reset internal state of scaler before fitting · scikit-learn/scikit-learn@6eb1ddc · GitHub
[go: up one dir, main page]

Skip to content

Commit 6eb1ddc

Browse files
author
giorgiop
committed
BUG: reset internal state of scaler before fitting
1 parent 3301893 commit 6eb1ddc

File tree

2 files changed

+76
-0
lines changed

2 files changed

+76
-0
lines changed

sklearn/preprocessing/data.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,15 +252,36 @@ def data_range(self):
252252
def data_min(self):
253253
return self.data_min_
254254

255+
def _reset(self):
256+
"""Reset internal data-dependent state of the scaler, if necessary.
257+
258+
__init__ parameters are not touched.
259+
"""
260+
261+
# Checking one attribute is enough, because they are all set together
262+
# in partial_fit
263+
if hasattr(self, 'scale_'):
264+
del self.scale_
265+
del self.min_
266+
del self.n_samples_seen_
267+
del self.data_min_
268+
del self.data_max_
269+
del self.data_range_
270+
255271
def fit(self, X, y=None):
256272
"""Compute the minimum and maximum to be used for later scaling.
257273
274+
It always resets the scaler's internal state before fitting.
275+
258276
Parameters
259277
----------
260278
X : array-like, shape [n_samples, n_features]
261279
The data used to compute the per-feature minimum and maximum
262280
used for later scaling along the features axis.
263281
"""
282+
283+
# Reset internal state before fitting
284+
self._reset()
264285
return self.partial_fit(X, y)
265286

266287
def partial_fit(self, X, y=None):
@@ -489,6 +510,20 @@ def __init__(self, copy=True, with_mean=True, with_std=True):
489510
def std_(self):
490511
return self.scale_
491512

513+
def _reset(self):
514+
"""Reset internal data-dependent state of the scaler, if necessary.
515+
516+
__init__ parameters are not touched.
517+
"""
518+
519+
# Checking one attribute is enough, because they are all set together
520+
# in partial_fit
521+
if hasattr(self, 'scale_'):
522+
del self.scale_
523+
del self.n_samples_seen_
524+
del self.mean_
525+
del self.var_
526+
492527
def fit(self, X, y=None):
493528
"""Compute the mean and std to be used for later scaling.
494529
@@ -500,6 +535,9 @@ def fit(self, X, y=None):
500535
501536
y: Passthrough for ``Pipeline`` compatibility.
502537
"""
538+
539+
# Reset internal state before fitting
540+
self._reset()
503541
return self.partial_fit(X, y)
504542

505543
def partial_fit(self, X, y=None):
@@ -671,6 +709,19 @@ class MaxAbsScaler(BaseEstimator, TransformerMixin):
671709
def __init__(self, copy=True):
672710
self.copy = copy
673711

712+
def _reset(self):
713+
"""Reset internal data-dependent state of the scaler, if necessary.
714+
715+
__init__ parameters are not touched.
716+
"""
717+
718+
# Checking one attribute is enough, because they are all set together
719+
# in partial_fit
720+
if hasattr(self, 'scale_'):
721+
del self.scale_
722+
del self.n_samples_seen_
723+
del self.max_abs_
724+
674725
def fit(self, X, y=None):
675726
"""Compute the maximum absolute value to be used for later scaling.
676727
@@ -680,6 +731,9 @@ def fit(self, X, y=None):
680731
The data used to compute the per-feature minimum and maximum
681732
used for later scaling along the features axis.
682733
"""
734+
735+
# Reset internal state before fitting
736+
self._reset()
683737
return self.partial_fit(X, y)
684738

685739
def partial_fit(self, X, y=None):

sklearn/preprocessing/tests/test_data.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,3 +1498,25 @@ def test_one_hot_encoder_unknown_transform():
14981498
oh = OneHotEncoder(handle_unknown='42')
14991499
oh.fit(X)
15001500
assert_raises(ValueError, oh.transform, y)
1501+
1502+
1503+
def test_fit_cold_start():
1504+
X = iris.data
1505+
X_2d = X[:, :2]
1506+
1507+
# Scalers that have a partial_fit method
1508+
scalers = [StandardScaler(with_mean=False, with_std=False),
1509+
MinMaxScaler(),
1510+
MaxAbsScaler()]
1511+
1512+
for scaler in scalers:
1513+
scaler.fit_transform(X)
1514+
# with a different shape, this may break the scaler unless the internal
1515+
# state is reset
1516+
try:
1517+
scaler.fit_transform(X_2d)
1518+
except ValueError as err:
1519+
print("Cannot fit %s a second time with different shape "
1520+
"of input. Error message: %s"
1521+
% (scaler.__class__.__name__, str(err)))
1522+
assert False

0 commit comments

Comments
 (0)
0