From 87dbe430181278267d62c2ba14d49757979d5498 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 Jul 2018 12:57:44 -0400 Subject: [PATCH 1/4] Add `check_input` option to `_preprocess_data` Includes a `check_input` option to `_preprocess_data`, which checks to see if `check_array` should be run on the input array or not. As the `check_array` step can be expensive (particularly when trying to detect non-finite values), it is important to have an option to skip this step particularly when it shows up in a tight loop like with Lasso in Dictionary Learning. The default is to keep this check. Make sure to still copy the array if that is requested even if the check is disabled. --- sklearn/linear_model/base.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index 09c389cb336d7..e6d504189c7a8 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -68,7 +68,7 @@ def make_dataset(X, y, sample_weight, random_state=None): def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True, - sample_weight=None, return_mean=False): + sample_weight=None, return_mean=False, check_input=True): """ Centers data to have mean zero along axis 0. If fit_intercept=False or if the X is a sparse matrix, no centering is done, but normalization can still @@ -90,8 +90,15 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True, if isinstance(sample_weight, numbers.Number): sample_weight = None - X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'], - dtype=FLOAT_DTYPES) + if check_input: + X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'], + dtype=FLOAT_DTYPES) + elif copy: + if sp.issparse(X): + X = X.copy() + else: + X = X.copy(order='K') + y = np.asarray(y, dtype=X.dtype) if fit_intercept: From f21ba6f51bb8d31a08fd55f307cdc3d95e5044d8 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 Jul 2018 12:57:46 -0400 Subject: [PATCH 2/4] Add option to disable `check_input` in `_pre_fit` Many of the places using `_preprocess_data` do so directly and/or via `_pre_fit`. As such it makes sense to expose the `check_input` parameter in `_pre_fit` as well. --- sklearn/linear_model/base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/linear_model/base.py b/sklearn/linear_model/base.py index e6d504189c7a8..059c8f9939b23 100644 --- a/sklearn/linear_model/base.py +++ b/sklearn/linear_model/base.py @@ -448,7 +448,8 @@ def fit(self, X, y, sample_weight=None): return self -def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy): +def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy, + check_input=True): """Aux function used at beginning of fit in linear models""" n_samples, n_features = X.shape @@ -457,11 +458,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy): precompute = False X, y, X_offset, y_offset, X_scale = _preprocess_data( X, y, fit_intercept=fit_intercept, normalize=normalize, - copy=False, return_mean=True) + copy=False, return_mean=True, check_input=check_input) else: # copy was done in fit if necessary X, y, X_offset, y_offset, X_scale = _preprocess_data( - X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy) + X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy, + check_input=check_input) if hasattr(precompute, '__array__') and ( fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or normalize and not np.allclose(X_scale, np.ones(n_features))): From 9a59e7381203bc7240eb60c353e577afbe966156 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Thu, 12 Jul 2018 13:32:52 -0400 Subject: [PATCH 3/4] Pass `check_input` to `_pre_fit` in `ElasticNet` If we want to skip `check_array` calls in `ElasticNet` or `Lasso`, we should disable them in `_pre_fit` as well. Otherwise we will still have to do the same checks that we skipped previously. --- sklearn/linear_model/coordinate_descent.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/linear_model/coordinate_descent.py b/sklearn/linear_model/coordinate_descent.py index bdad75bc6197a..f78f917c7b16b 100644 --- a/sklearn/linear_model/coordinate_descent.py +++ b/sklearn/linear_model/coordinate_descent.py @@ -418,7 +418,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None, if check_input: X, y, X_offset, y_offset, X_scale, precompute, Xy = \ _pre_fit(X, y, Xy, precompute, normalize=False, - fit_intercept=False, copy=False) + fit_intercept=False, copy=False, check_input=check_input) if alphas is None: # No need to normalize of fit_intercept: it has been done # above @@ -717,7 +717,8 @@ def fit(self, X, y, check_input=True): should_copy = self.copy_X and not X_copied X, y, X_offset, y_offset, X_scale, precompute, Xy = \ _pre_fit(X, y, None, self.precompute, self.normalize, - self.fit_intercept, copy=should_copy) + self.fit_intercept, copy=should_copy, + check_input=check_input) if y.ndim == 1: y = y[:, np.newaxis] if Xy is not None and Xy.ndim == 1: From 387fb13f8265ded80d3bb522f3bdbcfd0fa30342 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 18 Jul 2018 12:19:55 -0400 Subject: [PATCH 4/4] TST: Run _preprocess_data on sparse without checks Make sure that even when checks are disabled, the sparse input still gets copied if requested. --- sklearn/linear_model/tests/test_base.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/sklearn/linear_model/tests/test_base.py b/sklearn/linear_model/tests/test_base.py index 30e4cfdcced42..bcabe12ed35f3 100644 --- a/sklearn/linear_model/tests/test_base.py +++ b/sklearn/linear_model/tests/test_base.py @@ -3,6 +3,8 @@ # # License: BSD 3 clause +import pytest + import numpy as np from scipy import sparse from scipy import linalg @@ -321,6 +323,28 @@ def test_csr_preprocess_data(): assert_equal(csr_.getformat(), 'csr') +@pytest.mark.parametrize('is_sparse', (True, False)) +@pytest.mark.parametrize('to_copy', (True, False)) +def test_preprocess_copy_data_no_checks(is_sparse, to_copy): + X, y = make_regression() + X[X < 2.5] = 0.0 + + if is_sparse: + X = sparse.csr_matrix(X) + + X_, y_, _, _, _ = _preprocess_data(X, y, True, + copy=to_copy, check_input=False) + + if to_copy and is_sparse: + assert not np.may_share_memory(X_.data, X.data) + elif to_copy: + assert not np.may_share_memory(X_, X) + elif is_sparse: + assert np.may_share_memory(X_.data, X.data) + else: + assert np.may_share_memory(X_, X) + + def test_dtype_preprocess_data(): n_samples = 200 n_features = 2