[MRG] MNT: Optionally skip another input check in ElasticNet, Lasso by jakirkham · Pull Request #11487 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

[MRG] MNT: Optionally skip another input check in ElasticNet, Lasso #11487

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions sklearn/linear_model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def make_dataset(X, y, sample_weight, random_state=None):


def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
sample_weight=None, return_mean=False):
sample_weight=None, return_mean=False, check_input=True):
"""
Centers data to have mean zero along axis 0. If fit_intercept=False or if
the X is a sparse matrix, no centering is done, but normalization can still
Expand All @@ -90,8 +90,15 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
if isinstance(sample_weight, numbers.Number):
sample_weight = None

X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
dtype=FLOAT_DTYPES)
if check_input:
X = check_array(X, copy=copy, accept_sparse=['csr', 'csc'],
dtype=FLOAT_DTYPES)
elif copy:
if sp.issparse(X):
X = X.copy()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't covered by tests.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, think I've fixed this.

else:
X = X.copy(order='K')

y = np.asarray(y, dtype=X.dtype)

if fit_intercept:
Expand Down Expand Up @@ -441,7 +448,8 @@ def fit(self, X, y, sample_weight=None):
return self


def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy,
check_input=True):
"""Aux function used at beginning of fit in linear models"""
n_samples, n_features = X.shape

Expand All @@ -450,11 +458,12 @@ def _pre_fit(X, y, Xy, precompute, normalize, fit_intercept, copy):
precompute = False
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, fit_intercept=fit_intercept, normalize=normalize,
copy=False, return_mean=True)
copy=False, return_mean=True, check_input=check_input)
else:
# copy was done in fit if necessary
X, y, X_offset, y_offset, X_scale = _preprocess_data(
X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy)
X, y, fit_intercept=fit_intercept, normalize=normalize, copy=copy,
check_input=check_input)
if hasattr(precompute, '__array__') and (
fit_intercept and not np.allclose(X_offset, np.zeros(n_features)) or
normalize and not np.allclose(X_scale, np.ones(n_features))):
Expand Down
5 changes: 3 additions & 2 deletions sklearn/linear_model/coordinate_descent.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def enet_path(X, y, l1_ratio=0.5, eps=1e-3, n_alphas=100, alphas=None,
if check_input:
X, y, X_offset, y_offset, X_scale, precompute, Xy = \
_pre_fit(X, y, Xy, precompute, normalize=False,
fit_intercept=False, copy=False)
fit_intercept=False, copy=False, check_input=check_input)
if alphas is None:
# No need to normalize or fit_intercept: it has been done
# above
Expand Down Expand Up @@ -717,7 +717,8 @@ def fit(self, X, y, check_input=True):
should_copy = self.copy_X and not X_copied
X, y, X_offset, y_offset, X_scale, precompute, Xy = \
_pre_fit(X, y, None, self.precompute, self.normalize,
self.fit_intercept, copy=should_copy)
self.fit_intercept, copy=should_copy,
check_input=check_input)
if y.ndim == 1:
y = y[:, np.newaxis]
if Xy is not None and Xy.ndim == 1:
Expand Down
24 changes: 24 additions & 0 deletions sklearn/linear_model/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#
# License: BSD 3 clause

import pytest

import numpy as np
from scipy import sparse
from scipy import linalg
Expand Down Expand Up @@ -321,6 +323,28 @@ def test_csr_preprocess_data():
assert_equal(csr_.getformat(), 'csr')


@pytest.mark.parametrize('is_sparse', (True, False))
@pytest.mark.parametrize('to_copy', (True, False))
def test_preprocess_copy_data_no_checks(is_sparse, to_copy):
    """Check that ``_preprocess_data`` honours ``copy`` without input checks.

    With ``check_input=False`` the returned ``X`` must not share memory with
    the input when ``copy`` is true, and must share memory with it when
    ``copy`` is false. Both a dense array and a CSR matrix are exercised.
    """
    X, y = make_regression()
    # Zero out every entry below the threshold before the optional CSR
    # conversion.
    X[X < 2.5] = 0.0

    if is_sparse:
        X = sparse.csr_matrix(X)

    X_, y_, _, _, _ = _preprocess_data(X, y, fit_intercept=True,
                                       copy=to_copy, check_input=False)

    # For sparse matrices compare the underlying value buffers; for dense
    # arrays compare the arrays themselves.
    if is_sparse:
        result_buffer, input_buffer = X_.data, X.data
    else:
        result_buffer, input_buffer = X_, X

    shares_memory = np.may_share_memory(result_buffer, input_buffer)
    if to_copy:
        assert not shares_memory
    else:
        assert shares_memory


def test_dtype_preprocess_data():
n_samples = 200
n_features = 2
Expand Down
0