From 945c8d3dadc57dac58a3444df127e5e5418ab59f Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Mon, 22 Feb 2021 19:17:24 +0100
Subject: [PATCH 1/6] Fix scaler on near-constant features

---
 sklearn/linear_model/_base.py            |  6 ++-
 sklearn/linear_model/tests/test_base.py  | 24 ++++++---
 sklearn/preprocessing/_data.py           | 21 +++++---
 sklearn/preprocessing/tests/test_data.py | 63 ++++++++++++++++++++++--
 4 files changed, 95 insertions(+), 19 deletions(-)

diff --git a/sklearn/linear_model/_base.py b/sklearn/linear_model/_base.py
index 61005cb4b5d4a..28cc386b4ecda 100644
--- a/sklearn/linear_model/_base.py
+++ b/sklearn/linear_model/_base.py
@@ -246,9 +246,13 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
         X_var = X_var.astype(X.dtype, copy=False)

         if normalize:
+            # Detect constant features on the computed variance, before taking
+            # the np.sqrt. Otherwise constant features cannot be detected with
+            # sample_weights.
+            constant_mask = X_var < 10 * np.finfo(X.dtype).eps
             X_var *= X.shape[0]
             X_scale = np.sqrt(X_var, out=X_var)
-            X_scale[X_scale < 10 * np.finfo(X_scale.dtype).eps] = 1.
+            X_scale[constant_mask] = 1.
             if sp.issparse(X):
                 inplace_column_scale(X, 1. / X_scale)
             else:
diff --git a/sklearn/linear_model/tests/test_base.py b/sklearn/linear_model/tests/test_base.py
index 56ee18f5f0d06..bf7a2696fcda2 100644
--- a/sklearn/linear_model/tests/test_base.py
+++ b/sklearn/linear_model/tests/test_base.py
@@ -478,10 +478,8 @@ def test_preprocess_data_weighted(is_sparse):
     # better check the impact of feature scaling.
     X[:, 0] *= 10

-    # Constant non-zero feature: this edge-case is currently not handled
-    # correctly for sparse data, see:
-    # https://github.com/scikit-learn/scikit-learn/issues/19450
-    # X[:, 2] = 1.
+    # Constant non-zero feature.
+    X[:, 2] = 1.

     # Constant zero feature (non-materialized in the sparse case)
     X[:, 3] = 0.
@@ -495,10 +493,12 @@ def test_preprocess_data_weighted(is_sparse):
     X_sample_weight_var = np.average((X - X_sample_weight_avg)**2,
                                      weights=sample_weight,
                                      axis=0)
+    constant_mask = X_sample_weight_var < 10 * np.finfo(X.dtype).eps
+    assert_array_equal(constant_mask, [0, 0, 1, 1])
     expected_X_scale = np.sqrt(X_sample_weight_var) * np.sqrt(n_samples)

     # near constant features should not be scaled
-    expected_X_scale[expected_X_scale < 10 * np.finfo(np.float64).eps] = 1
+    expected_X_scale[constant_mask] = 1

     if is_sparse:
         X = sparse.csr_matrix(X)
@@ -538,14 +538,22 @@ def test_preprocess_data_weighted(is_sparse):
     # _preprocess_data with normalize=True scales the data by the feature-wise
     # euclidean norms while StandardScaler scales the data by the feature-wise
     # standard deviations.
-    # The two are equivalent up to a ratio of np.sqrt(n_samples)
+    # The two are equivalent up to a ratio of np.sqrt(n_samples).
     if is_sparse:
         scaler = StandardScaler(with_mean=False).fit(
             X, sample_weight=sample_weight)

+        # Non-constant features are scaled similarly with np.sqrt(n_samples)
         assert_array_almost_equal(
-            scaler.transform(X).toarray() / np.sqrt(n_samples), Xt.toarray()
-        )
+            scaler.transform(X).toarray()[:, :2] / np.sqrt(n_samples),
+            Xt.toarray()[:, :2]
+        )
+
+        # Constant features go through un-scaled.
+        assert_array_almost_equal(
+            scaler.transform(X).toarray()[:, 2:],
+            Xt.toarray()[:, 2:]
+        )
     else:
         scaler = StandardScaler(with_mean=True).fit(
             X, sample_weight=sample_weight)
diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py
index 92a4135147b87..f8280a5374ac6 100644
--- a/sklearn/preprocessing/_data.py
+++ b/sklearn/preprocessing/_data.py
@@ -60,22 +60,27 @@
 ]


-def _handle_zeros_in_scale(scale, copy=True):
+def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
     """Makes sure that whenever scale is zero, we handle it correctly.

     This happens in most scalers when we have constant features.
     """
-
     # if we are fitting on 1D arrays, scale might be a scalar
     if np.isscalar(scale):
         if scale == .0:
             scale = 1.
         return scale
     elif isinstance(scale, np.ndarray):
+        if constant_mask is None:
+            # Detect near constant values to avoid dividing by a very small
+            # value that could lead to surprising results and numerical
+            # stability issues.
+            constant_mask = scale < 10 * np.finfo(scale.dtype).eps
+
         if copy:
             # New array to avoid side-effects
             scale = scale.copy()
-        scale[scale == 0.0] = 1.0
+        scale[constant_mask] = 1.0
         return scale
@@ -408,7 +413,7 @@ def partial_fit(self, X, y=None):
         data_range = data_max - data_min
         self.scale_ = ((feature_range[1] - feature_range[0]) /
-                       _handle_zeros_in_scale(data_range))
+                       _handle_zeros_in_scale(data_range, copy=True))
         self.min_ = feature_range[0] - data_min * self.scale_
         self.data_min_ = data_min
         self.data_max_ = data_max
@@ -850,7 +855,11 @@ def partial_fit(self, X, y=None, sample_weight=None):
             self.n_samples_seen_ = self.n_samples_seen_[0]

         if self.with_std:
-            self.scale_ = _handle_zeros_in_scale(np.sqrt(self.var_))
+            # Extract the list of near constant features on the raw variances,
+            # before taking the square root.
+            constant_mask = self.var_ < 10 * np.finfo(X.dtype).eps
+            self.scale_ = _handle_zeros_in_scale(
+                np.sqrt(self.var_), copy=False, constant_mask=constant_mask)
         else:
             self.scale_ = None
@@ -1078,7 +1087,7 @@ def partial_fit(self, X, y=None):
             self.n_samples_seen_ += X.shape[0]

         self.max_abs_ = max_abs
-        self.scale_ = _handle_zeros_in_scale(max_abs)
+        self.scale_ = _handle_zeros_in_scale(max_abs, copy=True)
         return self

     def transform(self, X):
diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 974dad31258eb..67186f3991499 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -4,6 +4,7 @@
 #
 # License: BSD 3 clause

+from inspect import signature
 import warnings
 import itertools

@@ -414,6 +415,61 @@ def test_standard_scaler_dtype(add_sample_weight, sparse_constructor):
         assert scaler.scale_.dtype == np.float64


+@pytest.mark.parametrize("scaler", [
+    StandardScaler(with_mean=False),
+    RobustScaler(with_centering=False),
+])
+@pytest.mark.parametrize("sparse_constructor",
+                         [np.asarray, sparse.csc_matrix, sparse.csr_matrix])
+@pytest.mark.parametrize("add_sample_weight", [False, True])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+@pytest.mark.parametrize("constant", [0, 1., 100.])
+def test_standard_scaler_constant_features(
+        scaler, add_sample_weight, sparse_constructor, dtype, constant):
+    if (isinstance(scaler, StandardScaler)
+            and constant > 1
+            and sparse_constructor is not None
+            and add_sample_weight):
+        pytest.xfail("Computation of weighted variance is numerically unstable"
+                     " for sparse data")
+
+    if isinstance(scaler, RobustScaler) and add_sample_weight:
+        pytest.skip(f"{scaler.__class__.__name__} does not yet support"
+                    f" sample_weight")
+
+    rng = np.random.RandomState(0)
+    n_samples = 100
+    n_features = 1
+    if add_sample_weight:
+        fit_params = dict(sample_weight=rng.uniform(size=n_samples) * 2)
+    else:
+        fit_params = {}
+    X_array = np.full(shape=(n_samples, n_features), fill_value=constant,
+                      dtype=dtype)
+    X = sparse_constructor(X_array)
+    X_scaled = scaler.fit(X, **fit_params).transform(X)
+
+    if isinstance(scaler, StandardScaler):
+        # The variance info should be close to zero for constant features.
+        assert_allclose(scaler.var_, np.zeros(X.shape[1]), atol=1e-7)
+
+    # Constant features should not be scaled (scale of 1.):
+    assert_allclose(scaler.scale_, np.ones(X.shape[1]))
+
+    if hasattr(X_scaled, "toarray"):
+        assert_allclose(X_scaled.toarray(), X_array)
+    else:
+        assert_allclose(X_scaled, X)
+
+    if isinstance(scaler, StandardScaler) and not add_sample_weight:
+        # Also check consistency with the standard scale function.
+        X_scaled_2 = scale(X, with_mean=scaler.with_mean)
+        if hasattr(X_scaled_2, "toarray"):
+            assert_allclose(X_scaled_2.toarray(), X_scaled.toarray())
+        else:
+            assert_allclose(X_scaled_2, X_scaled)
+
+
 def test_scale_1d():
     # 1-d inputs
     X_list = [1., 3., 5., 0.]
@@ -538,12 +594,11 @@ def test_scaler_float16_overflow():


 def test_handle_zeros_in_scale():
-    s1 = np.array([0, 1, 2, 3])
+    s1 = np.array([0, 1e-16, 1, 2, 3])
     s2 = _handle_zeros_in_scale(s1, copy=True)
-    assert not s1[0] == s2[0]
-    assert_array_equal(s1, np.array([0, 1, 2, 3]))
-    assert_array_equal(s2, np.array([1, 1, 2, 3]))
+    assert_allclose(s1, np.array([0, 1e-16, 1, 2, 3]))
+    assert_allclose(s2, np.array([1, 1, 1, 2, 3]))


 def test_minmax_scaler_partial_fit():

From 660cc8ec9ca8fc03b18ebf87f4eae12a1e609e8e Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Mon, 22 Feb 2021 19:38:54 +0100
Subject: [PATCH 2/6] Remove useless import

---
 sklearn/preprocessing/tests/test_data.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py
index 67186f3991499..f1b5be6e4be0b 100644
--- a/sklearn/preprocessing/tests/test_data.py
+++ b/sklearn/preprocessing/tests/test_data.py
@@ -4,7 +4,6 @@
 #
 # License: BSD 3 clause

-from inspect import signature
 import warnings
 import itertools


From 59536defa9a8f137962a65d2fad83bd485294b6a Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Mon, 22 Feb 2021 19:40:25 +0100
Subject: [PATCH 3/6] Update changelog

---
 doc/whats_new/v1.0.rst | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/doc/whats_new/v1.0.rst b/doc/whats_new/v1.0.rst
index 25e0b369bebd3..c84ece12ef71d 100644
--- a/doc/whats_new/v1.0.rst
+++ b/doc/whats_new/v1.0.rst
@@ -187,6 +187,13 @@ Changelog
   positioning strategy ``knots``.
   :pr:`18368` by :user:`Christian Lorentzen `.

+- |Fix| :func:`preprocessing.scale`, :class:`preprocessing.StandardScaler`
+  and similar scalers detect near-constant features to avoid scaling them to
+  very large values. This problem happens in particular when using a scaler on
+  sparse data with a constant column with sample weights, in which case
+  centering is typically disabled. :pr:`19527` by :user:`Olivier Grisel
+  ` and :user:`Maria Telenczuk `.
+
 :mod:`sklearn.tree`
 ...................

From 139d37d5d2d9e5a108fe7cefa728a0c0ccc488bc Mon Sep 17 00:00:00 2001
From: Olivier Grisel
Date: Tue, 23 Feb 2021 16:23:18 +0100
Subject: [PATCH 4/6] Update sklearn/preprocessing/_data.py

---
 sklearn/preprocessing/_data.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py
index f8280a5374ac6..a0a8f51bf3222 100644
--- a/sklearn/preprocessing/_data.py
+++ b/sklearn/preprocessing/_data.py
@@ -61,9 +61,18 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None):
-    """Makes sure that whenever scale is zero, we handle it correctly.
-
-    This happens in most scalers when we have constant features.
+    """Set scales of near constant features to 1.
+
+    The goal is to avoid division by very small or zero values.
+
+    Near constant features are detected automatically by identifying
+    scales close to machine precision unless they are precomputed by
+    the caller and passed with the `constant_mask` kwarg.
+
+    Typically for standard scaling, the scales are the standard
+    deviation while near constant features are better detected on the
+    computed variances which are closer to machine precision by
+    construction.
""" # if we are fitting on 1D arrays, scale might be a scalar if np.isscalar(scale): From d3015ebf1cdfbbf8ba27e4a499bd80a41e728a28 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Tue, 23 Feb 2021 16:25:18 +0100 Subject: [PATCH 5/6] Update sklearn/preprocessing/_data.py --- sklearn/preprocessing/_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/preprocessing/_data.py b/sklearn/preprocessing/_data.py index a0a8f51bf3222..29190dd6e2b67 100644 --- a/sklearn/preprocessing/_data.py +++ b/sklearn/preprocessing/_data.py @@ -68,7 +68,7 @@ def _handle_zeros_in_scale(scale, copy=True, constant_mask=None): Near constant features are detected automatically by identifying scales close to machine precision unless they are precomputed by the caller and passed with the `constant_mask` kwarg. - + Typically for standard scaling, the scales are the standard deviation while near constant features are better detected on the computed variances which are closer to machine precision by From b612785e0fd5931de4f9d554876451467ad2a523 Mon Sep 17 00:00:00 2001 From: Olivier Grisel Date: Wed, 24 Feb 2021 14:30:20 +0100 Subject: [PATCH 6/6] Update sklearn/preprocessing/tests/test_data.py --- sklearn/preprocessing/tests/test_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/preprocessing/tests/test_data.py b/sklearn/preprocessing/tests/test_data.py index f1b5be6e4be0b..fdd88be0ccff4 100644 --- a/sklearn/preprocessing/tests/test_data.py +++ b/sklearn/preprocessing/tests/test_data.py @@ -427,10 +427,11 @@ def test_standard_scaler_constant_features( scaler, add_sample_weight, sparse_constructor, dtype, constant): if (isinstance(scaler, StandardScaler) and constant > 1 - and sparse_constructor is not None + and sparse_constructor is not np.asarray and add_sample_weight): + # https://github.com/scikit-learn/scikit-learn/issues/19546 pytest.xfail("Computation of weighted variance is numerically unstable" - " for sparse data") + " for sparse data. See: #19546.") if isinstance(scaler, RobustScaler) and add_sample_weight: pytest.skip(f"{scaler.__class__.__name__} does not yet support"