Remove extra data copies when not needed · scikit-learn/scikit-learn@00a69de · GitHub
[go: up one dir, main page]

Skip to content

Commit 00a69de

Browse files
committed
Remove extra data copies when not needed
1 parent 278c641 commit 00a69de

File tree

2 files changed

+92
-19
lines changed

2 files changed

+92
-19
lines changed

sklearn/linear_model/_base.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def _preprocess_data(
187187
fit_intercept,
188188
normalize=False,
189189
copy=True,
190+
copy_y=True,
190191
sample_weight=None,
191192
check_input=True,
192193
):
@@ -230,13 +231,14 @@ def _preprocess_data(
230231

231232
if check_input:
232233
X = check_array(X, copy=copy, accept_sparse=["csr", "csc"], dtype=FLOAT_DTYPES)
233-
elif copy:
234-
if sp.issparse(X):
235-
X = X.copy()
236-
else:
237-
X = X.copy(order="K")
238-
239-
y = np.asarray(y, dtype=X.dtype)
234+
y = check_array(y, dtype=X.dtype, copy=copy_y, ensure_2d=False)
235+
else:
236+
y = y.astype(X.dtype, copy=copy_y)
237+
if copy:
238+
if sp.issparse(X):
239+
X = X.copy()
240+
else:
241+
X = X.copy(order="K")
240242

241243
if fit_intercept:
242244
if sp.issparse(X):
@@ -276,7 +278,7 @@ def _preprocess_data(
276278
X_scale = np.ones(X.shape[1], dtype=X.dtype)
277279

278280
y_offset = np.average(y, axis=0, weights=sample_weight)
279-
y = y - y_offset
281+
y -= y_offset
280282
else:
281283
X_offset = np.zeros(X.shape[1], dtype=X.dtype)
282284
X_scale = np.ones(X.shape[1], dtype=X.dtype)
@@ -293,7 +295,7 @@ def _preprocess_data(
293295
# sample_weight makes the refactoring tricky.
294296

295297

296-
def _rescale_data(X, y, sample_weight):
298+
def _rescale_data(X, y, sample_weight, inplace=False):
297299
"""Rescale data sample-wise by square root of sample_weight.
298300
299301
For many linear models, this enables easy support for sample_weight because
@@ -328,18 +330,24 @@ def _rescale_data(X, y, sample_weight):
328330
if sp.issparse(X):
329331
X = safe_sparse_dot(sw_matrix, X)
330332
else:
331-
# XXX: we do not do inplace multiplication on X for consistency
332-
# with the sparse case and because the _rescale_data currently
333-
# does not make it explicit if it's ok to do it or not.
334-
X = X * sample_weight_sqrt[:, np.newaxis]
333+
if inplace:
334+
X *= sample_weight_sqrt[:, np.newaxis]
335+
else:
336+
X = X * sample_weight_sqrt[:, np.newaxis]
335337

336338
if sp.issparse(y):
337339
y = safe_sparse_dot(sw_matrix, y)
338340
else:
339-
if y.ndim == 1:
340-
y = y * sample_weight_sqrt
341+
if inplace:
342+
if y.ndim == 1:
343+
y *= sample_weight_sqrt
344+
else:
345+
y *= sample_weight_sqrt[:, np.newaxis]
341346
else:
342-
y = y * sample_weight_sqrt[:, np.newaxis]
347+
if y.ndim == 1:
348+
y = y * sample_weight_sqrt
349+
else:
350+
y = y * sample_weight_sqrt[:, np.newaxis]
343351
return X, y, sample_weight_sqrt
344352

345353

@@ -674,17 +682,26 @@ def fit(self, X, y, sample_weight=None):
674682
sample_weight, X, dtype=X.dtype, only_non_negative=True
675683
)
676684

685+
# Note that neither _rescale_data nor the rest of the fit method of
686+
# LinearRegression can benefit from in-place operations when X is a
687+
# sparse matrix. Therefore, let's not copy X when it is sparse.
688+
copy_X_in_preprocess_data = self.copy_X and not sp.issparse(X)
689+
677690
X, y, X_offset, y_offset, X_scale = _preprocess_data(
678691
X,
679692
y,
680693
fit_intercept=self.fit_intercept,
681-
copy=self.copy_X,
694+
copy=copy_X_in_preprocess_data,
682695
sample_weight=sample_weight,
683696
)
684697

685-
# Sample weight can be implemented via a simple rescaling.
686698
if has_sw:
687-
X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight)
699+
# Sample weight can be implemented via a simple rescaling. Note
700+
# that we safely do inplace rescaling when _preprocess_data has
701+
# already made a copy if requested.
702+
X, y, sample_weight_sqrt = _rescale_data(
703+
X, y, sample_weight, inplace=copy_X_in_preprocess_data
704+
)
688705

689706
if self.positive:
690707
if y.ndim < 2:

sklearn/linear_model/tests/test_base.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,62 @@ def test_linear_regression_positive_vs_nonpositive_when_positive(global_random_s
315315
assert np.mean((reg.coef_ - regn.coef_) ** 2) < 1e-6
316316

317317

318+
@pytest.mark.parametrize("sparse_X", [True, False])
319+
@pytest.mark.parametrize("use_sw", [True, False])
320+
def test_inplace_data_preprocessing(sparse_X, use_sw, global_random_seed):
321+
# Check that the data is not modified inplace by the linear regression
322+
# estimator.
323+
rng = np.random.RandomState(global_random_seed)
324+
original_X_data = rng.randn(10, 12)
325+
original_y_data = rng.randn(10, 2)
326+
orginal_sw_data = rng.rand(10)
327+
328+
if sparse_X:
329+
X = sparse.csr_matrix(original_X_data)
330+
else:
331+
X = original_X_data.copy()
332+
y = original_y_data.copy()
333+
# XXX: Note that y_sparse is not supported (broken?) in the current
334+
# implementation of LinearRegression.
335+
336+
if use_sw:
337+
sample_weight = orginal_sw_data.copy()
338+
else:
339+
sample_weight = None
340+
341+
# Do not allow inplace preprocessing of X and y:
342+
reg = LinearRegression()
343+
reg.fit(X, y, sample_weight=sample_weight)
344+
if sparse_X:
345+
assert_allclose(X.toarray(), original_X_data)
346+
else:
347+
assert_allclose(X, original_X_data)
348+
assert_allclose(y, original_y_data)
349+
350+
if use_sw:
351+
assert_allclose(sample_weight, orginal_sw_data)
352+
353+
# Allow inplace preprocessing of X and y
354+
reg = LinearRegression(copy_X=False)
355+
reg.fit(X, y, sample_weight=sample_weight)
356+
if sparse_X:
357+
# No optimization relying on the inplace modification of sparse input
358+
# data has been implemented at this time.
359+
assert_allclose(X.toarray(), original_X_data)
360+
else:
361+
# X has been offset (and optionally rescaled by sample weights)
362+
# inplace. The 0.42 threshold is arbitrary and has been found to be
363+
# robust to any random seed in the admissible range.
364+
assert np.linalg.norm(X - original_X_data) > 0.42
365+
366+
# y should not have been modified inplace by LinearRegression.fit.
367+
assert_allclose(y, original_y_data)
368+
369+
if use_sw:
370+
# Sample weights have no reason to ever be modified inplace.
371+
assert_allclose(sample_weight, orginal_sw_data)
372+
373+
318374
def test_linear_regression_pd_sparse_dataframe_warning():
319375
pd = pytest.importorskip("pandas")
320376

0 commit comments

Comments
 (0)
0