8000 PERF fix overhead of _rescale_data in LinearRegression (#26207) · scikit-learn/scikit-learn@5a332e7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a332e7

Browse files
authored
PERF fix overhead of _rescale_data in LinearRegression (#26207)
1 parent 559609f commit 5a332e7

File tree

3 files changed

+163
-34
lines changed

3 files changed

+163
-34
lines changed

doc/whats_new/v1.3.rst

+7-2
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ Changelog
267267
meaning that it is note required to call `fit` before calling `transform`.
268268
Parameter validation only happens at `fit` time.
269269
:pr:`24230` by :user:`Guillaume Lemaitre <glemaitre>`.
270-
270+
271271
:mod:`sklearn.feature_selection`
272272
................................
273273

@@ -294,6 +294,11 @@ Changelog
294294
:mod:`sklearn.linear_model`
295295
...........................
296296

297+
- |Efficiency| Avoid data scaling when `sample_weight=None` and other
298+
unnecessary data copies and unexpected dense to sparse data conversion in
299+
:class:`linear_model.LinearRegression`.
300+
:pr:`26207` by :user:`Olivier Grisel <ogrisel>`.
301+
297302
- |Enhancement| :class:`linear_model.SGDClassifier`,
298303
:class:`linear_model.SGDRegressor` and :class:`linear_model.SGDOneClassSVM`
299304
now preserve dtype for `numpy.float32`.
@@ -309,7 +314,7 @@ Changelog
309314
:class:`linear_model.ARDRegression` to expose the actual number of iterations
310315
required to reach the stopping criterion.
311316
:pr:`25697` by :user:`John Pangas <jpangas>`.
312-
317+
313318
- |Fix| Use a more robust criterion to detect convergence of
314319
:class:`linear_model.LogisticRegression(penalty="l1", solver="liblinear")`
315320
on linearly separable problems.

sklearn/linear_model/_base.py

+72-25
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def _preprocess_data(
187187
fit_intercept,
188188
normalize=False,
189189
copy=True,
190+
copy_y=True,
190191
sample_weight=None,
191192
check_input=True,
192193
):
@@ -230,13 +231,14 @@ def _preprocess_data(
230231

231232
if check_input:
232233
X = check_array(X, copy=copy, accept_sparse=["csr", "csc"], dtype=FLOAT_DTYPES)
233-
elif copy:
234-
if sp.issparse(X):
235-
X = X.copy()
236-
else:
237-
X = X.copy(order="K")
238-
239-
y = np.asarray(y, dtype=X.dtype)
234+
y = check_array(y, dtype=X.dtype, copy=copy_y, ensure_2d=False)
235+
else:
236+
y = y.astype(X.dtype, copy=copy_y)
237+
if copy:
238+
if sp.issparse(X):
239+
X = X.copy()
240+
else:
241+
X = X.copy(order="K")
240242

241243
if fit_intercept:
242244
if sp.issparse(X):
@@ -276,7 +278,7 @@ def _preprocess_data(
276278
X_scale = np.ones(X.shape[1], dtype=X.dtype)
277279

278280
y_offset = np.average(y, axis< F438 /span>=0, weights=sample_weight)
279-
y = y - y_offset
281+
y -= y_offset
280282
else:
281283
X_offset = np.zeros(X.shape[1], dtype=X.dtype)
282284
X_scale = np.ones(X.shape[1], dtype=X.dtype)
@@ -293,7 +295,7 @@ def _preprocess_data(
293295
# sample_weight makes the refactoring tricky.
294296

295297

296-
def _rescale_data(X, y, sample_weight):
298+
def _rescale_data(X, y, sample_weight, inplace=False):
297299
"""Rescale data sample-wise by square root of sample_weight.
298300
299301
For many linear models, this enables easy support for sample_weight because
@@ -315,14 +317,37 @@ def _rescale_data(X, y, sample_weight):
315317
316318
y_rescaled : {array-like, sparse matrix}
317319
"""
320+
# Assume that _validate_data and _check_sample_weight have been called by
321+
# the caller.
318322
n_samples = X.shape[0]
319-
sample_weight = np.asarray(sample_weight)
320-
if sample_weight.ndim == 0:
321-
sample_weight = np.full(n_samples, sample_weight, dtype=sample_weight.dtype)
322323
sample_weight_sqrt = np.sqrt(sample_weight)
323-
sw_matrix = sparse.dia_matrix((sample_weight_sqrt, 0), shape=(n_samples, n_samples))
324-
X = safe_sparse_dot(sw_matrix, X)
325-
y = safe_sparse_dot(sw_matrix, y)
324+
325+
if sp.issparse(X) or sp.issparse(y):
326+
sw_matrix = sparse.dia_matrix(
327+
(sample_weight_sqrt, 0), shape=(n_samples, n_samples)
328+
)
329+
330+
if sp.issparse(X):
331+
X = safe_sparse_dot(sw_matrix, X)
332+
else:
333+
if inplace:
334+
X *= sample_weight_sqrt[:, np.newaxis]
335+
else:
336+
X = X * sample_weight_sqrt[:, np.newaxis]
337+
338+
if sp.issparse(y):
339+
y = safe_sparse_dot(sw_matrix, y)
340+
else:
341+
if inplace:
342+
if y.ndim == 1:
343+
y *= sample_weight_sqrt
344+
else:
345+
y *= sample_weight_sqrt[:, np.newaxis]
346+
else:
347+
if y.ndim == 1:
348+
y = y * sample_weight_sqrt
349+
else:
350+
y = y * sample_weight_sqrt[:, np.newaxis]
326351
return X, y, sample_weight_sqrt
327352

328353

@@ -651,20 +676,32 @@ def fit(self, X, y, sample_weight=None):
651676
X, y, accept_sparse=accept_sparse, y_numeric=True, multi_output=True
652677
)
653678

654-
sample_weight = _check_sample_weight(
655-
sample_weight, X, dtype=X.dtype, only_non_negative=True
656-
)
679+
has_sw = sample_weight is not None
680+
if has_sw:
681+
sample_weight = _check_sample_weight(
682+
sample_weight, X, dtype=X.dtype, only_non_negative=True
683+
)
684+
685+
# Note that neither _rescale_data nor the rest of the fit method of
686+
# LinearRegression can benefit from in-place operations when X is a
687+
# sparse matrix. Therefore, let's not copy X when it is sparse.
688+
copy_X_in_preprocess_data = self.copy_X and not sp.issparse(X)
657689

658690
X, y, X_offset, y_offset, X_scale = _preprocess_data(
659691
X,
660692
y,
661693
fit_intercept=self.fit_intercept,
662-
copy=self.copy_X,
694+
copy=copy_X_in_preprocess_data,
663695
sample_weight=sample_weight,
664696
)
665697

666-
# Sample weight can be implemented via a simple rescaling.
667-
X, y, sample_weight_sqrt = _rescale_data(X, y, sample_weight)
698+
if has_sw:
699+
# Sample weight can be implemented via a simple rescaling. Note
700+
# that we safely do inplace rescaling when _preprocess_data has
701+
# already made a copy if requested.
702+
X, y, sample_weight_sqrt = _rescale_data(
703+
X, y, sample_weight, inplace=copy_X_in_preprocess_data
704+
)
668705

669706
if self.positive:
670707
if y.ndim < 2:
@@ -678,11 +715,21 @@ def fit(self, X, y, sample_weight=None):
678715
elif sp.issparse(X):
679716
X_offset_scale = X_offset / X_scale
680717

681-
def matvec(b):
682-
return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale)
718+
if has_sw:
719+
720+
def matvec(b):
721+
return X.dot(b) - sample_weight_sqrt * b.dot(X_offset_scale)
722+
723+
def rmatvec(b):
724+
return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt)
725+
726+
else:
727+
728+
def matvec(b):
729+
return X.dot(b) - b.dot(X_offset_scale)
683730

684-
def rmatvec(b):
685-
return X.T.dot(b) - X_offset_scale * b.dot(sample_weight_sqrt)
731+
def rmatvec(b):
732+
return X.T.dot(b) - X_offset_scale * b.sum()
686733

687734
X_centered = sparse.linalg.LinearOperator(
688735
shape=X.shape, matvec=matvec, rmatvec=rmatvec

sklearn/linear_model/tests/test_base.py

+84-7
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,62 @@ def test_linear_regression_positive_vs_nonpositive_when_positive(global_random_s
315315
assert np.mean((reg.coef_ - regn.coef_) ** 2) < 1e-6
316316

317317

318+
@pytest.mark.parametrize("sparse_X", [True, False])
319+
@pytest.mark.parametrize("use_sw", [True, False])
320+
def test_inplace_data_preprocessing(sparse_X, use_sw, global_random_seed):
321+
# Check that the data is not modified inplace by the linear regression
322+
# estimator.
323+
rng = np.random.RandomState(global_random_seed)
324+
original_X_data = rng.randn(10, 12)
325+
original_y_data = rng.randn(10, 2)
326+
orginal_sw_data = rng.rand(10)
327+
328+
if sparse_X:
329+
X = sparse.csr_matrix(original_X_data)
330+
else:
331+
X = original_X_data.copy()
332+
y = original_y_data.copy()
333+
# XXX: Note hat y_sparse is not supported (broken?) in the current
334+
# implementation of LinearRegression.
335+
336+
if use_sw:
337+
sample_weight = orginal_sw_data.copy()
338+
else:
339+
sample_weight = None
340+
341+
# Do not allow inplace preprocessing of X and y:
342+
reg = LinearRegression()
343+
reg.fit(X, y, sample_weight=sample_weight)
344+
if sparse_X:
345+
assert_allclose(X.toarray(), original_X_data)
346+
else:
347+
assert_allclose(X, original_X_data)
348+
assert_allclose(y, original_y_data)
349+
350+
if use_sw:
351+
assert_allclose(sample_weight, orginal_sw_data)
352+
353+
# Allow inplace preprocessing of X and y
354+
reg = LinearRegression(copy_X=False)
355+
reg.fit(X, y, sample_weight=sample_weight)
356+
if sparse_X:
357+
# No optimization relying on the inplace modification of sparse input
358+
# data has been implemented at this time.
359+
assert_allclose(X.toarray(), original_X_data)
360+
else:
361+
# X has been offset (and optionally rescaled by sample weights)
362+
# inplace. The 0.42 threshold is arbitrary and has been found to be
363+
# robust to any random seed in the admissible range.
364+
assert np.linalg.norm(X - original_X_data) > 0.42
365+
366+
# y should not have been modified inplace by LinearRegression.fit.
367+
assert_allclose(y, original_y_data)
368+
369+
if use_sw:
370+
# Sample weights have no reason to ever be modified inplace.
371+
assert_allclose(sample_weight, orginal_sw_data)
372+
373+
318374
def test_linear_regression_pd_sparse_dataframe_warning():
319375
pd = pytest.importorskip("pandas")
320376

@@ -661,7 +717,8 @@ def test_dtype_preprocess_data(global_random_seed):
661717

662718

663719
@pytest.mark.parametrize("n_targets", [None, 2])
664-
def test_rescale_data_dense(n_targets, global_random_seed):
720+
@pytest.mark.parametrize("sparse_data", [True, False])
721+
def test_rescale_data(n_targets, sparse_data, global_random_seed):
665722
rng = np.random.RandomState(global_random_seed)
666723
n_samples = 200
667724
n_features = 2
@@ -672,14 +729,34 @@ def test_rescale_data_dense(n_targets, global_random_seed):
672729
y = rng.rand(n_samples)
673730
else:
674731
y = rng.rand(n_samples, n_targets)
675-
rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
676-
rescaled_X2 = X * sqrt_sw[:, np.newaxis]
732+
733+
expected_sqrt_sw = np.sqrt(sample_weight)
734+
expected_rescaled_X = X * expected_sqrt_sw[:, np.newaxis]
735+
677736
if n_targets is None:
678-
rescaled_y2 = y * sqrt_sw
737+
expected_rescaled_y = y * expected_sqrt_sw
679738
else:
680-
rescaled_y2 = y * sqrt_sw[:, np.newaxis]
681-
assert_array_almost_equal(rescaled_X, rescaled_X2)
682-
assert_array_almost_equal(rescaled_y, rescaled_y2)
739+
expected_rescaled_y = y * expected_sqrt_sw[:, np.newaxis]
740+
741+
if sparse_data:
742+
X = sparse.csr_matrix(X)
743+
if n_targets is None:
744+
y = sparse.csr_matrix(y.reshape(-1, 1))
745+
else:
746+
y = sparse.csr_matrix(y)
747+
748+
rescaled_X, rescaled_y, sqrt_sw = _rescale_data(X, y, sample_weight)
749+
750+
assert_allclose(sqrt_sw, expected_sqrt_sw)
751+
752+
if sparse_data:
753+
rescaled_X = rescaled_X.toarray()
754+
rescaled_y = rescaled_y.toarray()
755+
if n_targets is None:
756+
rescaled_y = rescaled_y.ravel()
757+
758+
assert_allclose(rescaled_X, expected_rescaled_X)
759+
assert_allclose(rescaled_y, expected_rescaled_y)
683760

684761

685762
def test_fused_types_make_dataset():

0 commit comments

Comments
 (0)
0