8000 FEA: Ridge support for Array API compliant inputs (#27800) · charlesjhill/scikit-learn@171e124 · GitHub
[go: up one dir, main page]

Skip to content

Commit 171e124

Browse files
fcharraselindgrenogriselbetatimlesteve
authored
FEA: Ridge support for Array API compliant inputs (scikit-learn#27800)
Co-authored-by: Eric Lindgren <ericlin@chalmers.se> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Tim Head <betatim@gmail.com> Co-authored-by: Loïc Estève <loic.esteve@ymail.com> Co-authored-by: Oleksii Kachaiev <kachayev@gmail.com> Co-authored-by: Julien Jerphanion <git@jjerphan.xyz> Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2a2643f commit 171e124

File tree

9 files changed

+413
-87
lines changed

9 files changed

+413
-87
lines changed

doc/modules/array_api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Estimators
9595

9696
- :class:`decomposition.PCA` (with `svd_solver="full"`,
9797
`svd_solver="randomized"` and `power_iteration_normalizer="QR"`)
98+
- :class:`linear_model.Ridge` (with `solver="svd"`)
9899
- :class:`discriminant_analysis.LinearDiscriminantAnalysis` (with `solver="svd"`)
99100
- :class:`preprocessing.KernelCenterer`
100101
- :class:`preprocessing.MaxAbsScaler`

doc/whats_new/v1.5.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ See :ref:`array_api` for more details.
7272

7373
**Classes:**
7474

75+
- :class:`linear_model.Ridge` now supports the Array API for the `svd` solver.
76+
See :ref:`array_api` for more details.
77+
:pr:`27800` by :user:`Franck Charras <fcharras>`, :user:`Olivier Grisel <ogrisel>`
78+
and :user:`Tim Head <betatim>`.
79+
7580
Support for building with Meson
7681
-------------------------------
7782

sklearn/linear_model/_base.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,14 @@
3333
_fit_context,
3434
)
3535
from ..utils import check_array, check_random_state
36-
from ..utils._array_api import get_namespace, indexing_dtype
36+
from ..utils._array_api import (
37+
_asarray_with_order,
38+
_average,
39+
get_namespace,
40+
get_namespace_and_device,
41+
indexing_dtype,
42+
supported_float_dtypes,
43+
)
3744
from ..utils._seq_dataset import (
3845
ArrayDataset32,
3946
ArrayDataset64,
@@ -43,7 +50,7 @@
4350
from ..utils.extmath import safe_sparse_dot
4451
from ..utils.parallel import Parallel, delayed
4552
from ..utils.sparsefuncs import mean_variance_axis
46-
from ..utils.validation import FLOAT_DTYPES, _check_sample_weight, check_is_fitted
53+
from ..utils.validation import _check_sample_weight, check_is_fitted
4754

4855
# TODO: bayesian_ridge_regression and bayesian_regression_ard
4956
# should be squashed into its respective objects.
@@ -155,43 +162,51 @@ def _preprocess_data(
155162
Always an array of ones. TODO: refactor the code base to make it
156163
possible to remove this unused variable.
157164
"""
165+
xp, _, device_ = get_namespace_and_device(X, y, sample_weight)
166+
n_samples, n_features = X.shape
167+
X_is_sparse = sp.issparse(X)
168+
158169
if isinstance(sample_weight, numbers.Number):
159170
sample_weight = None
160171
if sample_weight is not None:
161-
sample_weight = np.asarray(sample_weight)
172+
sample_weight = xp.asarray(sample_weight)
162173

163174
if check_input:
164-
X = check_array(X, copy=copy, accept_sparse=["csr", "csc"], dtype=FLOAT_DTYPES)
175+
X = check_array(
176+
X, copy=copy, accept_sparse=["csr", "csc"], dtype=supported_float_dtypes(xp)
177+
)
165178
y = check_array(y, dtype=X.dtype, copy=copy_y, ensure_2d=False)
166179
else:
167-
y = y.astype(X.dtype, copy=copy_y)
180+
y = xp.astype(y, X.dtype, copy=copy_y)
168181
if copy:
169-
if sp.issparse(X):
182+
if X_is_sparse:
170183
X = X.copy()
171184
else:
172-
X = X.copy(order="K")
185+
X = _asarray_with_order(X, order="K", copy=True, xp=xp)
186+
187+
dtype_ = X.dtype
173188

174189
if fit_intercept:
175-
if sp.issparse(X):
190+
if X_is_sparse:
176191
X_offset, X_var = mean_variance_axis(X, axis=0, weights=sample_weight)
177192
else:
178-
X_offset = np.average(X, axis=0, weights=sample_weight)
193+
X_offset = _average(X, axis=0, weights=sample_weight, xp=xp)
179194

180-
X_offset = X_offset.astype(X.dtype, copy=False)
195+
X_offset = xp.astype(X_offset, X.dtype, copy=False)
181196
X -= X_offset
182197

183-
y_offset = np.average(y, axis=0, weights=sample_weight)
198+
y_offset = _average(y, axis=0, weights=sample_weight, xp=xp)
184199
y -= y_offset
185200
else:
186-
X_offset = np.zeros(X.shape[1], dtype=X.dtype)
201+
X_offset = xp.zeros(n_features, dtype=X.dtype, device=device_)
187202
if y.ndim == 1:
188-
y_offset = X.dtype.type(0)
203+
y_offset = xp.asarray(0.0, dtype=dtype_, device=device_)
189204
else:
190-
y_offset = np.zeros(y.shape[1], dtype=X.dtype)
205+
y_offset = xp.zeros(y.shape[1], dtype=dtype_, device=device_)
191206

192207
# XXX: X_scale is no longer needed. It is an historic artifact from the
193208
# time where linear model exposed the normalize parameter.
194-
X_scale = np.ones(X.shape[1], dtype=X.dtype)
209+
X_scale = xp.ones(n_features, dtype=X.dtype, device=device_)
195210
return X, y, X_offset, y_offset, X_scale
196211

197212

@@ -224,8 +239,9 @@ def _rescale_data(X, y, sample_weight, inplace=False):
224239
"""
225240
# Assume that _validate_data and _check_sample_weight have been called by
226241
# the caller.
242+
xp, _ = get_namespace(X, y, sample_weight)
227243
n_samples = X.shape[0]
228-
sample_weight_sqrt = np.sqrt(sample_weight)
244+
sample_weight_sqrt = xp.sqrt(sample_weight)
229245

230246
if sp.issparse(X) or sp.issparse(y):
231247
sw_matrix = sparse.dia_matrix(
@@ -236,9 +252,9 @@ def _rescale_data(X, y, sample_weight, inplace=False):
236252
X = safe_sparse_dot(sw_matrix, X)
237253
else:
238254
if inplace:
239-
X *= sample_weight_sqrt[:, np.newaxis]
255+
X *= sample_weight_sqrt[:, None]
240256
else:
241-
X = X * sample_weight_sqrt[:, np.newaxis]
257+
X = X * sample_weight_sqrt[:, None]
242258

243259
if sp.issparse(y):
244260
y = safe_sparse_dot(sw_matrix, y)
@@ -247,12 +263,12 @@ def _rescale_data(X, y, sample_weight, inplace=False):
247263
if y.ndim == 1:
248264
y *= sample_weight_sqrt
249265
else:
250-
y *= sample_weight_sqrt[:, np.newaxis]
266+
y *= sample_weight_sqrt[:, None]
251267
else:
252268
if y.ndim == 1:
253269
y = y * sample_weight_sqrt
254270
else:
255-
y = y * sample_weight_sqrt[:, np.newaxis]
271+
y = y * sample_weight_sqrt[:, None]
256272
return X, y, sample_weight_sqrt
257273

258274

@@ -267,7 +283,11 @@ def _decision_function(self, X):
267283
check_is_fitted(self)
268284

269285
X = self._validate_data(X, accept_sparse=["csr", "csc", "coo"], reset=False)
270-
return safe_sparse_dot(X, self.coef_.T, dense_output=True) + self.intercept_
286+
coef_ = self.coef_
287+
if coef_.ndim == 1:
288+
return X @ coef_ + self.intercept_
289+
else:
290+
return X @ coef_.T + self.intercept_
271291

272292
def predict(self, X):
273293
AD86 """
@@ -287,11 +307,22 @@ def predict(self, X):
287307

288308
def _set_intercept(self, X_offset, y_offset, X_scale):
289309
"""Set the intercept_"""
310+
311+
xp, _ = get_namespace(X_offset, y_offset, X_scale)
312+
290313
if self.fit_intercept:
291314
# We always want coef_.dtype=X.dtype. For instance, X.dtype can differ from
292315
# coef_.dtype if warm_start=True.
293-
self.coef_ = np.divide(self.coef_, X_scale, dtype=X_scale.dtype)
294-
self.intercept_ = y_offset - np.dot(X_offset, self.coef_.T)
316+
coef_ = xp.astype(self.coef_, X_scale.dtype, copy=False)
317+
coef_ = self.coef_ = xp.divide(coef_, X_scale)
318+
319+
if coef_.ndim == 1:
320+
intercept_ = y_offset - X_offset @ coef_
321+
else:
322+
intercept_ = y_offset - X_offset @ coef_.T
323+
324+
self.intercept_ = intercept_
325+
295326
else:
296327
self.intercept_ = 0.0
297328

0 commit comments

Comments
 (0)
0