8000 Fix sparse logic in _rescale_data · scikit-learn/scikit-learn@c1e837f · GitHub
[go: up one dir, main page]

Skip to content

Commit c1e837f

Browse files
author
Christian Lorentzen
committed
Fix sparse logic in _rescale_data
1 parent 2bcffa3 commit c1e837f

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

sklearn/linear_model/_base.py

+17-11
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ def _preprocess_data(X, y, fit_intercept, normalize=False, copy=True,
184184
def _rescale_data(X, y, sample_weight, order='C'):
185185
"""Rescale data so as to support sample_weight"""
186186
n_samples = X.shape[0]
187+
sparse_X = sparse.issparse(X)
188+
sparse_y = sparse.issparse(y)
187189
sample_weight = np.array(sample_weight)
188190
if sample_weight.ndim == 0:
189191
sample_weight = np.full(n_samples, sample_weight,
@@ -193,22 +195,26 @@ def _rescale_data(X, y, sample_weight, order='C'):
193195
shape=(n_samples, n_samples))
194196
X = safe_sparse_dot(sw_matrix, X)
195197
y = safe_sparse_dot(sw_matrix, y)
196-
if sparse.issparse(X):
198+
199+
if sparse_X:
197200
if order == 'F':
198-
X = X.tocsc()
199-
if y.ndim > 1:
200-
y = y.tocsc()
201+
X = sparse.csc_matrix(X)
201202
else:
202-
X = X.tocsr()
203-
if y.ndim > 1:
204-
y = y.tocsr()
203+
X = sparse.csr_matrix(X)
204+
elif order == 'F':
205+
X = np.asfortranarray(X)
205206
else:
207+
X = np.ascontiguousarray(X)
208+
209+
if sparse_y:
206210
if order == 'F':
207-
X = np.asfortranarray(X)
208-
y = np.asfortranarray(y)
211+
y = sparse.csc_matrix(y)
209212
else:
210-
X = np.ascontiguousarray(X)
211-
y = np.ascontiguousarray(y)
213+
y = sparse.csr_matrix(y)
214+
elif order == 'F':
215+
y = np.asfortranarray(y)
216+
else:
217+
y = np.ascontiguousarray(y)
212218
return X, y
213219

214220

0 commit comments

Comments
 (0)
0