33
33
_fit_context ,
34
34
)
35
35
from ..utils import check_array , check_random_state
36
- from ..utils ._array_api import get_namespace , indexing_dtype
36
+ from ..utils ._array_api import (
37
+ _asarray_with_order ,
38
+ _average ,
39
+ get_namespace ,
40
+ get_namespace_and_device ,
41
+ indexing_dtype ,
42
+ supported_float_dtypes ,
43
+ )
37
44
from ..utils ._seq_dataset import (
38
45
ArrayDataset32 ,
39
46
ArrayDataset64 ,
43
50
from ..utils .extmath import safe_sparse_dot
44
51
from ..utils .parallel import Parallel , delayed
45
52
from ..utils .sparsefuncs import mean_variance_axis
46
- from ..utils .validation import FLOAT_DTYPES , _check_sample_weight , check_is_fitted
53
+ from ..utils .validation import _check_sample_weight , check_is_fitted
47
54
48
55
# TODO: bayesian_ridge_regression and bayesian_regression_ard
49
56
# should be squashed into their respective objects.
@@ -155,43 +162,51 @@ def _preprocess_data(
155
162
Always an array of ones. TODO: refactor the code base to make it
156
163
possible to remove this unused variable.
157
164
"""
165
+ xp , _ , device_ = get_namespace_and_device (X , y , sample_weight )
166
+ n_samples , n_features = X .shape
167
+ X_is_sparse = sp .issparse (X )
168
+
158
169
if isinstance (sample_weight , numbers .Number ):
159
170
sample_weight = None
160
171
if sample_weight is not None :
161
- sample_weight = np .asarray (sample_weight )
172
+ sample_weight = xp .asarray (sample_weight )
162
173
163
174
if check_input :
164
- X = check_array (X , copy = copy , accept_sparse = ["csr" , "csc" ], dtype = FLOAT_DTYPES )
175
+ X = check_array (
176
+ X , copy = copy , accept_sparse = ["csr" , "csc" ], dtype = supported_float_dtypes (xp )
177
+ )
165
178
y = check_array (y , dtype = X .dtype , copy = copy_y , ensure_2d = False )
166
179
else :
167
- y = y .astype (X .dtype , copy = copy_y )
180
+ y = xp .astype (y , X .dtype , copy = copy_y )
168
181
if copy :
169
- if sp . issparse ( X ) :
182
+ if X_is_sparse :
170
183
X = X .copy ()
171
184
else :
172
- X = X .copy (order = "K" )
185
+ X = _asarray_with_order (X , order = "K" , copy = True , xp = xp )
186
+
187
+ dtype_ = X .dtype
173
188
174
189
if fit_intercept :
175
- if sp . issparse ( X ) :
190
+ if X_is_sparse :
176
191
X_offset , X_var = mean_variance_axis (X , axis = 0 , weights = sample_weight )
177
192
else :
178
- X_offset = np . average (X , axis = 0 , weights = sample_weight )
193
+ X_offset = _average (X , axis = 0 , weights = sample_weight , xp = xp )
179
194
180
- X_offset = X_offset .astype (X .dtype , copy = False )
195
+ X_offset = xp .astype (X_offset , X .dtype , copy = False )
181
196
X -= X_offset
182
197
183
- y_offset = np . average (y , axis = 0 , weights = sample_weight )
198
+ y_offset = _average (y , axis = 0 , weights = sample_weight , xp = xp )
184
199
y -= y_offset
185
200
else :
186
- X_offset = np .zeros (X . shape [ 1 ] , dtype = X .dtype )
201
+ X_offset = xp .zeros (n_features , dtype = X .dtype , device = device_ )
187
202
if y .ndim == 1 :
188
- y_offset = X . dtype . type ( 0 )
203
+ y_offset = xp . asarray ( 0.0 , dtype = dtype_ , device = device_ )
189
204
else :
190
- y_offset = np .zeros (y .shape [1 ], dtype = X . dtype )
205
+ y_offset = xp .zeros (y .shape [1 ], dtype = dtype_ , device = device_ )
191
206
192
207
# XXX: X_scale is no longer needed. It is a historical artifact from the
193
208
# time when linear models exposed the normalize parameter.
194
- X_scale = np .ones (X . shape [ 1 ] , dtype = X .dtype )
209
+ X_scale = xp .ones (n_features , dtype = X .dtype , device = device_ )
195
210
return X , y , X_offset , y_offset , X_scale
196
211
197
212
@@ -224,8 +239,9 @@ def _rescale_data(X, y, sample_weight, inplace=False):
224
239
"""
225
240
# Assume that _validate_data and _check_sample_weight have been called by
226
241
# the caller.
242
+ xp , _ = get_namespace (X , y , sample_weight )
227
243
n_samples = X .shape [0 ]
228
- sample_weight_sqrt = np .sqrt (sample_weight )
244
+ sample_weight_sqrt = xp .sqrt (sample_weight )
229
245
230
246
if sp .issparse (X ) or sp .issparse (y ):
231
247
sw_matrix = sparse .dia_matrix (
@@ -236,9 +252,9 @@ def _rescale_data(X, y, sample_weight, inplace=False):
236
252
X = safe_sparse_dot (sw_matrix , X )
237
253
else :
238
254
if inplace :
239
- X *= sample_weight_sqrt [:, np . newaxis ]
255
+ X *= sample_weight_sqrt [:, None ]
240
256
else :
241
- X = X * sample_weight_sqrt [:, np . newaxis ]
257
+ X = X * sample_weight_sqrt [:, None ]
242
258
243
259
if sp .issparse (y ):
244
260
y = safe_sparse_dot (sw_matrix , y )
@@ -247,12 +263,12 @@ def _rescale_data(X, y, sample_weight, inplace=False):
247
263
if y .ndim == 1 :
248
264
y *= sample_weight_sqrt
249
265
else :
250
- y *= sample_weight_sqrt [:, np . newaxis ]
266
+ y *= sample_weight_sqrt [:, None ]
251
267
else :
252
268
if y .ndim == 1 :
253
269
y = y * sample_weight_sqrt
254
270
else :
255
- y = y * sample_weight_sqrt [:, np . newaxis ]
271
+ y = y * sample_weight_sqrt [:, None ]
256
272
return X , y , sample_weight_sqrt
257
273
258
274
@@ -267,7 +283,11 @@ def _decision_function(self, X):
267
283
check_is_fitted (self )
268
284
269
285
X = self ._validate_data (X , accept_sparse = ["csr" , "csc" , "coo" ], reset = False )
270
- return safe_sparse_dot (X , self .coef_ .T , dense_output = True ) + self .intercept_
286
+ coef_ = self .coef_
287
+ if coef_ .ndim == 1 :
288
+ return X @ coef_ + self .intercept_
289
+ else :
290
+ return X @ coef_ .T + self .intercept_
271
291
272
292
def predict (self , X ):
273
293
AD86
"""
@@ -287,11 +307,22 @@ def predict(self, X):
287
307
288
308
def _set_intercept (self , X_offset , y_offset , X_scale ):
289
309
"""Set the intercept_"""
310
+
311
+ xp , _ = get_namespace (X_offset , y_offset , X_scale )
312
+
290
313
if self .fit_intercept :
291
314
# We always want coef_.dtype=X.dtype. For instance, X.dtype can differ from
292
315
# coef_.dtype if warm_start=True.
293
- self .coef_ = np .divide (self .coef_ , X_scale , dtype = X_scale .dtype )
294
- self .intercept_ = y_offset - np .dot (X_offset , self .coef_ .T )
316
+ coef_ = xp .astype (self .coef_ , X_scale .dtype , copy = False )
317
+ coef_ = self .coef_ = xp .divide (coef_ , X_scale )
318
+
319
+ if coef_ .ndim == 1 :
320
+ intercept_ = y_offset - X_offset @ coef_
321
+ else :
322
+ intercept_ = y_offset - X_offset @ coef_ .T
323
+
324
+ self .intercept_ = intercept_
325
+
295
326
else :
296
327
self .intercept_ = 0.0
297
328
0 commit comments