8000 ENH Array API support for euclidean_distances and rbf_kernel (#29433) · scikit-learn/scikit-learn@e7af195 · GitHub
[go: up one dir, main page]

Skip to content

Commit e7af195

Browse files
OmarManzoorogrisel
andauthored
ENH Array API support for euclidean_distances and rbf_kernel (#29433)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 2b2e290 commit e7af195

File tree

8 files changed

+184
-46
lines changed

8 files changed

+184
-46
lines changed

doc/modules/array_api.rst

+18
9E81
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ Metrics
123123
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
124124
- :func:`sklearn.metrics.pairwise.chi2_kernel`
125125
- :func:`sklearn.metrics.pairwise.cosine_similarity`
126+
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
126127
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
128+
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
127129
- :func:`sklearn.metrics.r2_score`
128130
- :func:`sklearn.metrics.zero_one_loss`
129131

@@ -172,6 +174,8 @@ automatically skipped. Therefore it's important to run the tests with the
172174
pip install array-api-compat # and other libraries as needed
173175
pytest -k "array_api" -v
174176

177+
.. _mps_support:
178+
175179
Note on MPS device support
176180
--------------------------
177181

@@ -191,3 +195,17 @@ To enable the MPS support in PyTorch, set the environment variable
191195

192196
At the time of writing all scikit-learn tests should pass, however, the
193197
computational speed is not necessarily better than with the CPU device.
198+
199+
.. _device_support_for_float64:
200+
201+
Note on device support for ``float64``
202+
--------------------------------------
203+
204+
Certain operations within scikit-learn will automatically perform operations
205+
on floating-point values with `float64` precision to prevent overflows and ensure
206+
correctness (e.g., :func:`metrics.pairwise.euclidean_distances`). However,
207+
certain combinations of array namespaces and devices, such as `PyTorch on MPS`
208+
(see :ref:`mps_support`) do not support the `float64` data type. In these cases,
209+
scikit-learn will revert to using the `float32` data type instead. This can result in
210+
different behavior (typically numerically unstable results) compared to not using array
211+
API dispatching or using a device with `float64` support.

doc/whats_new/v1.6.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ See :ref:`array_api` for more details.
4343
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
4444
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
4545
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
46-
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
46+
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
47+
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
48+
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.
4749

4850
**Classes:**
4951

sklearn/decomposition/_base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from scipy import linalg
1010

1111
from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
12-
from ..utils._array_api import _add_to_diagonal, device, get_namespace
12+
from ..utils._array_api import _fill_or_add_to_diagonal, device, get_namespace
1313
from ..utils.validation import check_is_fitted
1414

1515

@@ -47,7 +47,7 @@ def get_covariance(self):
4747
xp.asarray(0.0, device=device(exp_var)),
4848
)
4949
cov = (components_.T * exp_var_diff) @ components_
50-
_add_to_diagonal(cov, self.noise_variance_, xp)
50+
_fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
5151
return cov
5252

5353
def get_precision(self):
@@ -89,10 +89,10 @@ def get_precision(self):
8989
xp.asarray(0.0, device=device(exp_var)),
9090
)
9191
precision = components_ @ components_.T / self.noise_variance_
92-
_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
92+
_fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
9393
precision = components_.T @ linalg_inv(precision) @ components_
9494
precision /= -(self.noise_variance_**2)
95-
_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
95+
_fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
9696
return precision
9797

9898
@abstractmethod

sklearn/metrics/pairwise.py

+46-25
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,13 @@
2222
gen_even_slices,
2323
)
2424
from ..utils._array_api import (
25+
_fill_or_add_to_diagonal,
2526
_find_matching_floating_dtype,
2627
_is_numpy_namespace,
28+
_max_precision_float_dtype,
29+
_modify_in_place_if_numpy,
2730
get_namespace,
31+
get_namespace_and_device,
2832
)
2933
from ..utils._chunking import get_chunk_n_rows
3034
from ..utils._mask import _get_mask
@@ -335,13 +339,14 @@ def euclidean_distances(
335339
array([[1. ],
336340
[1.41421356]])
337341
"""
342+
xp, _ = get_namespace(X, Y)
338343
X, Y = check_pairwise_arrays(X, Y)
339344

340345
if X_norm_squared is not None:
341346
X_norm_squared = check_array(X_norm_squared, ensure_2d=False)
342347
original_shape = X_norm_squared.shape
343348
if X_norm_squared.shape == (X.shape[0],):
344-
X_norm_squared = X_norm_squared.reshape(-1, 1)
349+
X_norm_squared = xp.reshape(X_norm_squared, (-1, 1))
345350
if X_norm_squared.shape == (1, X.shape[0]):
346351
X_norm_squared = X_norm_squared.T
347352
if X_norm_squared.shape != (X.shape[0], 1):
@@ -354,7 +359,7 @@ def euclidean_distances(
354359
Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False)
355360
original_shape = Y_norm_squared.shape
356361
if Y_norm_squared.shape == (Y.shape[0],):
357-
Y_norm_squared = Y_norm_squared.reshape(1, -1)
362+
Y_norm_squared = xp.reshape(Y_norm_squared, (1, -1))
358363
if Y_norm_squared.shape == (Y.shape[0], 1):
359364
Y_norm_squared = Y_norm_squared.T
360365
if Y_norm_squared.shape != (1, Y.shape[0]):
@@ -375,24 +380,25 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
375380
float32, norms needs to be recomputed on upcast chunks.
376381
TODO: use a float64 accumulator in row_norms to avoid the latter.
377382
"""
378-
if X_norm_squared is not None and X_norm_squared.dtype != np.float32:
379-
XX = X_norm_squared.reshape(-1, 1)
380-
elif X.dtype != np.float32:
381-
XX = row_norms(X, squared=True)[:, np.newaxis]
383+
xp, _, device_ = get_namespace_and_device(X, Y)
384+
if X_norm_squared is not None and X_norm_squared.dtype != xp.float32:
385+
XX = xp.reshape(X_norm_squared, (-1, 1))
386+
elif X.dtype != xp.float32:
387+
XX = row_norms(X, squared=True)[:, None]
382388
else:
383389
XX = None
384390

385391
if Y is X:
386392
YY = None if XX is None else XX.T
387393
else:
388-
if Y_norm_squared is not None and Y_norm_squared.dtype != np.float32:
389-
YY = Y_norm_squared.reshape(1, -1)
390-
elif Y.dtype != np.float32:
391-
YY = row_norms(Y, squared=True)[np.newaxis, :]
394+
if Y_norm_squared is not None and Y_norm_squared.dtype != xp.float32:
395+
YY = xp.reshape(Y_norm_squared, (1, -1))
396+
elif Y.dtype != xp.float32:
397+
YY = row_norms(Y, squared=True)[None, :]
392398
else:
393399
YY = None
394400

395-
if X.dtype == np.float32 or Y.dtype == np.float32:
401+
if X.dtype == xp.float32 or Y.dtype == xp.float32:
396402
# To minimize precision issues with float32, we compute the distance
397403
# matrix on chunks of X and Y upcast to float64
398404
distances = _euclidean_distances_upcast(X, XX, Y, YY)
@@ -401,14 +407,22 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
401407
distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True)
402408
distances += XX
403409
distances += YY
404-
np.maximum(distances, 0, out=distances)
410+
411+
xp_zero = xp.asarray(0, device=device_, dtype=distances.dtype)
412+
distances = _modify_in_place_if_numpy(
413+
xp, xp.maximum, distances, xp_zero, out=distances
414+
)
405415

406416
# Ensure that distances between vectors and themselves are set to 0.0.
407417
# This may not be the case due to floating point rounding errors.
408418
if X is Y:
409-
np.fill_diagonal(distances, 0)
419+
_fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)
410420

411-
return distances if squared else np.sqrt(distances, out=distances)
421+
if squared:
422+
return distances
423+
424+
distances = _modify_in_place_if_numpy(xp, xp.sqrt, distances, out=distances)
425+
return distances
412426

413427

414428
@validate_params(
@@ -552,15 +566,20 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
552566
X and Y are upcast to float64 by chunks, which size is chosen to limit
553567
memory increase by approximately 10% (at least 10MiB).
554568
"""
569+
xp, _, device_ = get_namespace_and_device(X, Y)
555570
n_samples_X = X.shape[0]
556571
n_samples_Y = Y.shape[0]
557572
n_features = X.shape[1]
558573

559-
distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)
574+
distances = xp.empty((n_samples_X, n_samples_Y), dtype=xp.float32, device=device_)
560575

561576
if batch_size is None:
562-
x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
563-
y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1
577+
x_density = (
578+
X.nnz / xp.prod(X.shape) if issparse(X) else xp.asarray(1, device=device_)
579+
)
580+
y_density = (
581+
Y.nnz / xp.prod(Y.shape) if issparse(Y) else xp.asarray(1, device=device_)
582+
)
564583

565584
# Allow 10% more memory than X, Y and the distance matrix take (at
566585
# least 10MiB)
@@ -580,15 +599,15 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
580599
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
581600
# xd=x_density and yd=y_density
582601
tmp = (x_density + y_density) * n_features
583-
batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
602+
batch_size = (-tmp + xp.sqrt(tmp**2 + 4 * maxmem)) / 2
584603
batch_size = max(int(batch_size), 1)
585604

586605
x_batches = gen_batches(n_samples_X, batch_size)
587-
606+
xp_max_float = _max_precision_float_dtype(xp=xp, device=device_)
588607
for i, x_slice in enumerate(x_batches):
589-
X_chunk = X[x_slice].astype(np.float64)
608+
X_chunk = xp.astype(X[x_slice], xp_max_float)
590609
if XX is None:
591-
XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
610+
XX_chunk = row_norms(X_chunk, squared=True)[:, None]
592611
else:
593612
XX_chunk = XX[x_slice]
594613

@@ -601,17 +620,17 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
601620
d = distances[y_slice, x_slice].T
602621

603622
else:
604-
Y_chunk = Y[y_slice].astype(np.float64)
623+
Y_chunk = xp.astype(Y[y_slice], xp_max_float)
605624
if YY is None:
606-
YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
625+
YY_chunk = row_norms(Y_chunk, squared=True)[None, :]
607626
else:
608627
YY_chunk = YY[:, y_slice]
609628

610629
d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
611630
d += XX_chunk
612631
d += YY_chunk
613632

614-
distances[x_slice, y_slice] = d.astype(np.float32, copy=False)
633+
distances[x_slice, y_slice] = xp.astype(d, xp.float32, copy=False)
615634

616635
return distances
617636

@@ -1549,13 +1568,15 @@ def rbf_kernel(X, Y=None, gamma=None):
15491568
array([[0.71..., 0.51...],
15501569
[0.51..., 0.71...]])
15511570
"""
1571+
xp, _ = get_namespace(X, Y)
15521572
X, Y = check_pairwise_arrays(X, Y)
15531573
if gamma is None:
15541574
gamma = 1.0 / X.shape[1]
15551575

15561576
K = euclidean_distances(X, Y, squared=True)
15571577
K *= -gamma
1558-
np.exp(K, K) # exponentiate K in-place
1578+
# exponentiate K in-place when using numpy
1579+
K = _modify_in_place_if_numpy(xp, xp.exp, K, out=K)
15591580
return K
15601581

15611582

sklearn/metrics/tests/test_common.py

+4
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@
5555
additive_chi2_kernel,
5656
chi2_kernel,
5757
cosine_similarity,
58+
euclidean_distances,
5859
paired_cosine_distances,
60+
rbf_kernel,
5961
)
6062
from sklearn.preprocessing import LabelBinarizer
6163
from sklearn.utils import shuffle
@@ -2014,6 +2016,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20142016
mean_gamma_deviance: [check_array_api_regression_metric],
20152017
max_error: [check_array_api_regression_metric],
20162018
chi2_kernel: [check_array_api_metric_pairwise],
2019+
euclidean_distances: [check_array_api_metric_pairwise],
2020+
rbf_kernel: [check_array_api_metric_pairwise],
20172021
}
20182022

20192023

sklearn/utils/_array_api.py

+68-15
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,15 @@ def __eq__(self, other):
302302
def isdtype(self, dtype, kind):
303303
return isdtype(dtype, kind, xp=self._namespace)
304304

305+
def maximum(self, x1, x2):
306+
# TODO: Remove when `maximum` is made compatible in `array_api_compat`,
307+
# based on the `2023.12` specification.
308+
# https://github.com/data-apis/array-api-compat/issues/127
309+
x1_np = _convert_to_numpy(x1, xp=self._namespace)
310+
x2_np = _convert_to_numpy(x2, xp=self._namespace)
311+
x_max = numpy.maximum(x1_np, x2_np)
312+
return self._namespace.asarray(x_max, device=device(x1, x2))
313+
305314

306315
def _check_device_cpu(device): # noqa
307316
if device not in {"cpu", None}:
@@ -566,7 +575,28 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):
566575

567576

568577
def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
569-
"""Combination into one single function of `get_namespace` and `device`."""
578+
"""Combination into one single function of `get_namespace` and `device`.
579+
580+
Parameters
581+
----------
582+
*array_list : array objects
583+
Array objects.
584+
remove_none : bool, default=True
585+
Whether to ignore None objects passed in arrays.
586+
remove_types : tuple or list, default=(str,)
587+
Types to ignore in the arrays.
588+
589+
Returns
590+
-------
591+
namespace : module
592+
Namespace shared by array objects. If any of the `arrays` are not arrays,
593+
the namespace defaults to NumPy.
594+
is_array_api_compliant : bool
595+
True if the arrays are containers that implement the Array API spec.
596+
Always False when array_api_dispatch=False.
597+
device : device
598+
`device` object (see the "Device Support" section of the array API spec).
599+
"""
570600
array_list = _remove_non_arrays(
571601
*array_list, remove_none=remove_none, remove_types=remove_types
572602
)
@@ -592,21 +622,36 @@ def _expit(X, xp=None):
592622
return 1.0 / (1.0 + xp.exp(-X))
593623

594624

595-
def _add_to_diagonal(array, value, xp):
596-
# Workaround for the lack of support for xp.reshape(a, shape, copy=False) in
597-
# numpy.array_api: https://github.com/numpy/numpy/issues/23410
598-
value = xp.asarray(value, dtype=array.dtype)
599-
if _is_numpy_namespace(xp):
600-
array_np = numpy.asarray(array)
601-
array_np.flat[:: array.shape[0] + 1] += value
602-
return xp.asarray(array_np)
603-
elif value.ndim == 1:
604-
for i in range(array.shape[0]):
605-
array[i, i] += value[i]
625+
def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
626+
"""Implementation to facilitate adding or assigning specified values to the
627+
diagonal of a 2-d array.
628+
629+
If ``add_value`` is `True` then the values will be added to the diagonal
630+
elements otherwise the values will be assigned to the diagonal elements.
631+
By default, ``add_value`` is set to `True. This is currently only
632+
supported for 2-d arrays.
633+
634+
The implementation is taken from the `numpy.fill_diagonal` function:
635+
https://github.com/numpy/numpy/blob/v2.0.0/numpy/lib/_index_tricks_impl.py#L799-L929
636+
"""
637+
if array.ndim != 2:
638+
raise ValueError(
639+
f"array should be 2-d. Got array with shape {tuple(array.shape)}"
640+
)
641+
642+
value = xp.asarray(value, dtype=array.dtype, device=device(array))
643+
end = None
644+
# Explicit, fast formula for the common case. For 2-d arrays, we
645+
# accept rectangular ones.
646+
step = array.shape[1] + 1
647+
if not wrap:
648+
end = array.shape[1] * array.shape[1]
649+
650+
array_flat = xp.reshape(array, (-1,))
651+
if add_value:
652+
array_flat[:end:step] += value
606653
else:
607-
# scalar value
608-
for i in range(array.shape[0]):
609-
array[i, i] += value
654+
array_flat[:end:step] = value
610655

611656

612657
def _max_precision_float_dtype(xp, device):
@@ -1000,3 +1045,11 @@ def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
10001045

10011046
zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
10021047
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)
1048+
1049+
1050+
def _modify_in_place_if_numpy(xp, func, *args, out=None, **kwargs):
1051+
if _is_numpy_namespace(xp):
1052+
func(*args, out=out, **kwargs)
1053+
else:
1054+
out = func(*args, **kwargs)
1055+
return out

0 commit comments

Comments
 (0)
0