ENH Array API support for euclidean_distances and rbf_kernel #29433

Merged
Changes from 22 commits
Commits
30 commits
e611758
ENH Array API support for euclidean_distances and rbf_kernel
OmarManzoor Jul 8, 2024
ba092bd
Fix for cupy.array_api
OmarManzoor Jul 8, 2024
ab46fec
Add alternative for np.dot and add rbf kernel
OmarManzoor Jul 9, 2024
191e29b
Remove tensordot
OmarManzoor Jul 9, 2024
7fa28e2
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
OmarManzoor Jul 9, 2024
a7dc716
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
ogrisel Jul 9, 2024
5a4f031
Retain in-place operations where possible
OmarManzoor Jul 9, 2024
fcaaa64
Add changelog
OmarManzoor Jul 9, 2024
3aae9d2
Add pr number
OmarManzoor Jul 9, 2024
d07865f
Improve _xp_method_has_out
OmarManzoor Jul 9, 2024
f8f75b1
Fix for dtype in zero definition
OmarManzoor Jul 9, 2024
c2290fb
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
OmarManzoor Jul 10, 2024
41118c7
Updates: according to PR suggestions
OmarManzoor Jul 10, 2024
6deb173
Minor change in docstring
OmarManzoor Jul 10, 2024
3145042
Fix docstring
OmarManzoor Jul 10, 2024
8960391
Handle mps float32 separately
OmarManzoor Jul 10, 2024
e67ddd3
Remove maximum from array api wrapper
OmarManzoor Jul 10, 2024
a1715f4
Add device float support section
OmarManzoor Jul 10, 2024
d43931a
Add the maximum function back along with a test
OmarManzoor Jul 10, 2024
8318e56
Updates: further suggestions
OmarManzoor Jul 10, 2024
074a0f1
Add tests for fill_or_add_to_diagonal
OmarManzoor Jul 10, 2024
363aaef
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
OmarManzoor Jul 10, 2024
6fc5a0b
Minor change in docstring
OmarManzoor Jul 10, 2024
ead6ea8
Further updates: based on PR review
OmarManzoor Jul 11, 2024
63f1242
Further updates: based on PR review
OmarManzoor Jul 11, 2024
3499bcf
Update TODO statement
OmarManzoor Jul 11, 2024
2eec496
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
OmarManzoor Jul 11, 2024
68b1467
Improve error message.
ogrisel Jul 11, 2024
7073aad
Merge branch 'main' into array_api_euclidean_distances_rbf_kernel
OmarManzoor Jul 11, 2024
526d932
Fix linting
OmarManzoor Jul 11, 2024
18 changes: 18 additions & 0 deletions doc/modules/array_api.rst
@@ -123,7 +123,9 @@ Metrics
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
- :func:`sklearn.metrics.pairwise.chi2_kernel`
- :func:`sklearn.metrics.pairwise.cosine_similarity`
- :func:`sklearn.metrics.pairwise.euclidean_distances` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.pairwise.paired_cosine_distances`
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
- :func:`sklearn.metrics.r2_score`
- :func:`sklearn.metrics.zero_one_loss`

@@ -172,6 +174,8 @@ automatically skipped. Therefore it's important to run the tests with the
pip install array-api-compat # and other libraries as needed
pytest -k "array_api" -v

.. _mps_support:

Note on MPS device support
--------------------------

@@ -191,3 +195,17 @@ To enable the MPS support in PyTorch, set the environment variable

At the time of writing all scikit-learn tests should pass, however, the
computational speed is not necessarily better than with the CPU device.

.. _device_support_for_float64:

Note on device support for ``float64``
--------------------------------------

Certain operations within scikit-learn automatically perform computations
on floating-point values with `float64` precision to prevent overflows and ensure
correctness (e.g., :func:`metrics.pairwise.euclidean_distances`). However,
certain combinations of array namespaces and devices, such as `PyTorch on MPS`
(see :ref:`mps_support`), do not support the `float64` data type. In these cases,
scikit-learn will revert to using the `float32` data type instead. This can result in
different behavior (typically numerically unstable results) compared to not using array
API dispatching or using a device with `float64` support.
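The precision concern described in this section can be illustrated with a minimal NumPy-only sketch. It is not scikit-learn's implementation: `squared_distances_expanded` is a hypothetical helper that evaluates the expanded form ||x||^2 - 2 x.y + ||y||^2 in whatever dtype it receives, and the input values are chosen purely for illustration. The last computation mirrors the broadcasting approach that this PR uses for devices limited to `float32`.

```python
import numpy as np


def squared_distances_expanded(X, Y):
    """Squared Euclidean distances via ||x||^2 - 2 x.y + ||y||^2.

    Hypothetical helper for illustration; it computes in the dtype of X and Y.
    """
    XX = np.sum(X * X, axis=1)[:, None]
    YY = np.sum(Y * Y, axis=1)[None, :]
    distances = XX - 2 * (X @ Y.T) + YY
    # Rounding can push small distances below zero, so clip at zero.
    return np.maximum(distances, 0)


# Two points that are exactly one unit apart but far from the origin.
X32 = np.array([[100000.0, 0.0]], dtype=np.float32)
Y32 = np.array([[100001.0, 0.0]], dtype=np.float32)

d_float32 = squared_distances_expanded(X32, Y32)
d_float64 = squared_distances_expanded(
    X32.astype(np.float64), Y32.astype(np.float64)
)

# Direct differences stay accurate in float32, at the cost of an
# (n_samples_X, n_samples_Y, n_features) intermediate array.
d_direct = np.sum((X32[:, None] - Y32[None, :]) ** 2, axis=2)

print(d_float32[0, 0], d_float64[0, 0], d_direct[0, 0])
```

With these inputs the squared norms are close to 1e10, where adjacent `float32` values are 1024 apart, so the expanded form in `float32` cannot recover the true squared distance of 1. The `float64` evaluation and the direct-difference evaluation both can.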
6 changes: 4 additions & 2 deletions doc/whats_new/v1.6.rst
@@ -43,7 +43,9 @@ See :ref:`array_api` for more details.
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
- :func:`sklearn.metrics.pairwise.cosine_similarity` :pr:`29014` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`.
- :func:`sklearn.metrics.pairwise.euclidean_distances` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`;
- :func:`sklearn.metrics.pairwise.paired_cosine_distances` :pr:`29112` by :user:`Edoardo Abati <EdAbati>`;
- :func:`sklearn.metrics.pairwise.rbf_kernel` :pr:`29433` by :user:`Omar Salman <OmarManzoor>`.

**Classes:**

@@ -71,7 +73,7 @@ more details.
:class:`ensemble.StackingRegressor` now support metadata routing and pass
``**fit_params`` to the underlying estimators via their `fit` methods.
:pr:`28701` by :user:`Stefanie Senger <StefanieSenger>`.

- |Feature| :class:`compose.TransformedTargetRegressor` now supports metadata
routing in its `fit` and `predict` methods and routes the corresponding
params to the underlying regressor.
8 changes: 4 additions & 4 deletions sklearn/decomposition/_base.py
@@ -9,7 +9,7 @@
from scipy import linalg

from ..base import BaseEstimator, ClassNamePrefixFeaturesOutMixin, TransformerMixin
from ..utils._array_api import _add_to_diagonal, device, get_namespace
from ..utils._array_api import _fill_or_add_to_diagonal, device, get_namespace
from ..utils.validation import check_is_fitted


@@ -47,7 +47,7 @@ def get_covariance(self):
xp.asarray(0.0, device=device(exp_var)),
)
cov = (components_.T * exp_var_diff) @ components_
_add_to_diagonal(cov, self.noise_variance_, xp)
_fill_or_add_to_diagonal(cov, self.noise_variance_, xp)
return cov

def get_precision(self):
@@ -89,10 +89,10 @@ def get_precision(self):
xp.asarray(0.0, device=device(exp_var)),
)
precision = components_ @ components_.T / self.noise_variance_
_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
_fill_or_add_to_diagonal(precision, 1.0 / exp_var_diff, xp)
precision = components_.T @ linalg_inv(precision) @ components_
precision /= -(self.noise_variance_**2)
_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
_fill_or_add_to_diagonal(precision, 1.0 / self.noise_variance_, xp)
return precision

@abstractmethod
74 changes: 50 additions & 24 deletions sklearn/metrics/pairwise.py
@@ -22,9 +22,13 @@
gen_even_slices,
)
from ..utils._array_api import (
_fill_or_add_to_diagonal,
_find_matching_floating_dtype,
_is_numpy_namespace,
_max_precision_float_dtype,
_modify_in_place_if_numpy,
get_namespace,
get_namespace_and_device,
)
from ..utils._chunking import get_chunk_n_rows
from ..utils._mask import _get_mask
@@ -335,13 +339,14 @@ def euclidean_distances(
array([[1. ],
[1.41421356]])
"""
xp, _ = get_namespace(X, Y)
X, Y = check_pairwise_arrays(X, Y)

if X_norm_squared is not None:
X_norm_squared = check_array(X_norm_squared, ensure_2d=False)
original_shape = X_norm_squared.shape
if X_norm_squared.shape == (X.shape[0],):
X_norm_squared = X_norm_squared.reshape(-1, 1)
X_norm_squared = xp.reshape(X_norm_squared, (-1, 1))
if X_norm_squared.shape == (1, X.shape[0]):
X_norm_squared = X_norm_squared.T
if X_norm_squared.shape != (X.shape[0], 1):
@@ -354,7 +359,7 @@ def euclidean_distances(
Y_norm_squared = check_array(Y_norm_squared, ensure_2d=False)
original_shape = Y_norm_squared.shape
if Y_norm_squared.shape == (Y.shape[0],):
Y_norm_squared = Y_norm_squared.reshape(1, -1)
Y_norm_squared = xp.reshape(Y_norm_squared, (1, -1))
if Y_norm_squared.shape == (Y.shape[0], 1):
Y_norm_squared = Y_norm_squared.T
if Y_norm_squared.shape != (1, Y.shape[0]):
@@ -375,24 +380,30 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
float32, norms need to be recomputed on upcast chunks.
TODO: use a float64 accumulator in row_norms to avoid the latter.
"""
if X_norm_squared is not None and X_norm_squared.dtype != np.float32:
XX = X_norm_squared.reshape(-1, 1)
elif X.dtype != np.float32:
XX = row_norms(X, squared=True)[:, np.newaxis]
xp, _, device_ = get_namespace_and_device(X, Y)
if X_norm_squared is not None and X_norm_squared.dtype != xp.float32:
XX = xp.reshape(X_norm_squared, (-1, 1))
elif X.dtype != xp.float32:
XX = row_norms(X, squared=True)[:, None]
else:
XX = None

if Y is X:
YY = None if XX is None else XX.T
else:
if Y_norm_squared is not None and Y_norm_squared.dtype != np.float32:
YY = Y_norm_squared.reshape(1, -1)
elif Y.dtype != np.float32:
YY = row_norms(Y, squared=True)[np.newaxis, :]
if Y_norm_squared is not None and Y_norm_squared.dtype != xp.float32:
YY = xp.reshape(Y_norm_squared, (1, -1))
elif Y.dtype != xp.float32:
YY = row_norms(Y, squared=True)[None, :]
else:
YY = None

if X.dtype == np.float32 or Y.dtype == np.float32:
if _max_precision_float_dtype(xp=xp, device=device_) == xp.float32:
# special case for mps devices which don't support float64.
X_r = X[:, None]
Y_r = Y[None, :]
distances = xp.sum((X_r - Y_r) ** 2, axis=2)
elif X.dtype == xp.float32 or Y.dtype == xp.float32:
# To minimize precision issues with float32, we compute the distance
# matrix on chunks of X and Y upcast to float64
distances = _euclidean_distances_upcast(X, XX, Y, YY)
@@ -401,14 +412,22 @@ def _euclidean_distances(X, Y, X_norm_squared=None, Y_norm_squared=None, squared
distances = -2 * safe_sparse_dot(X, Y.T, dense_output=True)
distances += XX
distances += YY
np.maximum(distances, 0, out=distances)

xp_zero = xp.asarray(0, device=device_, dtype=distances.dtype)
distances = _modify_in_place_if_numpy(
xp, xp.maximum, distances, xp_zero, out=distances
)

# Ensure that distances between vectors and themselves are set to 0.0.
# This may not be the case due to floating point rounding errors.
if X is Y:
np.fill_diagonal(distances, 0)
_fill_or_add_to_diagonal(distances, 0, xp=xp, add_value=False)

if squared:
return distances

return distances if squared else np.sqrt(distances, out=distances)
distances = _modify_in_place_if_numpy(xp, xp.sqrt, distances, out=distances)
return distances


@validate_params(
@@ -552,15 +571,20 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
X and Y are upcast to float64 by chunks, whose size is chosen to limit
memory increase by approximately 10% (at least 10MiB).
"""
xp, _, device_ = get_namespace_and_device(X, Y)
n_samples_X = X.shape[0]
n_samples_Y = Y.shape[0]
n_features = X.shape[1]

distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)
distances = xp.empty((n_samples_X, n_samples_Y), dtype=xp.float32, device=device_)

if batch_size is None:
x_density = X.nnz / np.prod(X.shape) if issparse(X) else 1
y_density = Y.nnz / np.prod(Y.shape) if issparse(Y) else 1
x_density = (
X.nnz / np.prod(X.shape) if issparse(X) else xp.asarray(1, device=device_)
)
y_density = (
Y.nnz / np.prod(Y.shape) if issparse(Y) else xp.asarray(1, device=device_)
)

# Allow 10% more memory than X, Y and the distance matrix take (at
# least 10MiB)
@@ -580,15 +604,15 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
# Hence x² + (xd+yd)kx = M, where x=batch_size, k=n_features, M=maxmem
# xd=x_density and yd=y_density
tmp = (x_density + y_density) * n_features
batch_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2
batch_size = (-tmp + xp.sqrt(tmp**2 + 4 * maxmem)) / 2
batch_size = max(int(batch_size), 1)

x_batches = gen_batches(n_samples_X, batch_size)

for i, x_slice in enumerate(x_batches):
X_chunk = X[x_slice].astype(np.float64)
X_chunk = xp.astype(X[x_slice], xp.float64)
if XX is None:
XX_chunk = row_norms(X_chunk, squared=True)[:, np.newaxis]
XX_chunk = row_norms(X_chunk, squared=True)[:, None]
else:
XX_chunk = XX[x_slice]

@@ -601,17 +625,17 @@ def _euclidean_distances_upcast(X, XX=None, Y=None, YY=None, batch_size=None):
d = distances[y_slice, x_slice].T

else:
Y_chunk = Y[y_slice].astype(np.float64)
Y_chunk = xp.astype(Y[y_slice], xp.float64)
if YY is None:
YY_chunk = row_norms(Y_chunk, squared=True)[np.newaxis, :]
YY_chunk = row_norms(Y_chunk, squared=True)[None, :]
else:
YY_chunk = YY[:, y_slice]

d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True)
d += XX_chunk
d += YY_chunk

distances[x_slice, y_slice] = d.astype(np.float32, copy=False)
distances[x_slice, y_slice] = xp.astype(d, xp.float32, copy=False)

return distances
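The chunked upcast performed by `_euclidean_distances_upcast` can be sketched with NumPy alone. This is a simplified illustration under stated assumptions, not the function from the diff: `squared_distances_upcast_sketch` is a hypothetical name, it handles dense inputs only, it uses a fixed `batch_size` instead of the memory-based heuristic, and it does not reuse symmetric blocks when `X is Y`.

```python
import numpy as np


def squared_distances_upcast_sketch(X, Y, batch_size=2):
    """Chunked float64 evaluation of squared distances for float32 inputs.

    Hypothetical simplification: dense inputs only, fixed batch size.
    """
    n_samples_X, n_samples_Y = X.shape[0], Y.shape[0]
    distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32)

    for x_start in range(0, n_samples_X, batch_size):
        x_slice = slice(x_start, x_start + batch_size)
        # Upcast one chunk at a time so only a small float64 copy exists.
        X_chunk = X[x_slice].astype(np.float64)
        XX_chunk = np.sum(X_chunk**2, axis=1)[:, None]

        for y_start in range(0, n_samples_Y, batch_size):
            y_slice = slice(y_start, y_start + batch_size)
            Y_chunk = Y[y_slice].astype(np.float64)
            YY_chunk = np.sum(Y_chunk**2, axis=1)[None, :]

            d = -2 * (X_chunk @ Y_chunk.T)
            d += XX_chunk
            d += YY_chunk
            # Only the finished block is cast back down to float32.
            distances[x_slice, y_slice] = d.astype(np.float32, copy=False)

    return distances


X = np.array([[100000.0, 0.0], [0.0, 3.0], [1.0, 1.0]], dtype=np.float32)
Y = np.array([[100001.0, 0.0], [0.0, 0.0]], dtype=np.float32)
D = squared_distances_upcast_sketch(X, Y)
print(D.shape, D.dtype)
```

The output array stays in `float32`, so the extra memory is bounded by the size of the `float64` chunks rather than by the full distance matrix.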

@@ -1549,13 +1573,15 @@ def rbf_kernel(X, Y=None, gamma=None):
array([[0.71..., 0.51...],
[0.51..., 0.71...]])
"""
xp, _ = get_namespace(X, Y)
X, Y = check_pairwise_arrays(X, Y)
if gamma is None:
gamma = 1.0 / X.shape[1]

K = euclidean_distances(X, Y, squared=True)
K *= -gamma
np.exp(K, K) # exponentiate K in-place
# exponentiate K in-place when using numpy
K = _modify_in_place_if_numpy(xp, xp.exp, K, out=K)
return K
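For reference, the computation in `rbf_kernel` amounts to exponentiating the negated, scaled squared distances, with `gamma` defaulting to `1 / n_features`. The following NumPy-only sketch shows that arithmetic. `rbf_kernel_sketch` is a hypothetical name, and it computes squared distances by direct differences instead of calling `euclidean_distances`, so it is an illustration and not the library code path.

```python
import numpy as np


def rbf_kernel_sketch(X, Y=None, gamma=None):
    """K[i, j] = exp(-gamma * ||X[i] - Y[j]||^2), NumPy-only illustration."""
    X = np.asarray(X, dtype=np.float64)
    Y = X if Y is None else np.asarray(Y, dtype=np.float64)
    if gamma is None:
        gamma = 1.0 / X.shape[1]

    # Squared distances by direct differences (fine for small inputs).
    diff = X[:, None, :] - Y[None, :, :]
    K = np.sum(diff**2, axis=2)
    K *= -gamma
    np.exp(K, out=K)  # exponentiate in place, as the NumPy path does
    return K


X = np.array([[0.0, 0.0], [1.0, 1.0]])
K = rbf_kernel_sketch(X)
print(K)
```

Here the two points are at squared distance 2 and `gamma` is 0.5, so the off-diagonal entries equal `exp(-1)` and the diagonal entries equal 1.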


4 changes: 4 additions & 0 deletions sklearn/metrics/tests/test_common.py
@@ -55,7 +55,9 @@
additive_chi2_kernel,
chi2_kernel,
cosine_similarity,
euclidean_distances,
paired_cosine_distances,
rbf_kernel,
)
from sklearn.preprocessing import LabelBinarizer
from sklearn.utils import shuffle
@@ -2014,6 +2016,8 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
mean_gamma_deviance: [check_array_api_regression_metric],
max_error: [check_array_api_regression_metric],
chi2_kernel: [check_array_api_metric_pairwise],
euclidean_distances: [check_array_api_metric_pairwise],
rbf_kernel: [check_array_api_metric_pairwise],
}


78 changes: 63 additions & 15 deletions sklearn/utils/_array_api.py
@@ -302,6 +302,12 @@ def __eq__(self, other):
    def isdtype(self, dtype, kind):
        return isdtype(dtype, kind, xp=self._namespace)

    def maximum(self, x1, x2):
        x1_np = _convert_to_numpy(x1, xp=self._namespace)
        x2_np = _convert_to_numpy(x2, xp=self._namespace)
        x_max = numpy.maximum(x1_np, x2_np)
        return self._namespace.asarray(x_max, device=device(x1, x2))


def _check_device_cpu(device): # noqa
if device not in {"cpu", None}:
@@ -566,7 +572,28 @@ def get_namespace(*arrays, remove_none=True, remove_types=(str,), xp=None):


def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,)):
"""Combination into one single function of `get_namespace` and `device`."""
"""Combination into one single function of `get_namespace` and `device`.

Parameters
----------
*array_list : array objects
Array objects.
remove_none : bool, default=True
Whether to ignore None objects passed in arrays.
remove_types : tuple or list, default=(str,)
Types to ignore in the arrays.

Returns
-------
namespace : module
Namespace shared by array objects. If any of the `arrays` are not arrays,
the namespace defaults to NumPy.
is_array_api_compliant : bool
True if the arrays are containers that implement the Array API spec.
Always False when array_api_dispatch=False.
device : device
`device` object (see the "Device Support" section of the array API spec).
"""
array_list = _remove_non_arrays(
*array_list, remove_none=remove_none, remove_types=remove_types
)
@@ -592,21 +619,34 @@ def _expit(X, xp=None):
return 1.0 / (1.0 + xp.exp(-X))


def _add_to_diagonal(array, value, xp):
# Workaround for the lack of support for xp.reshape(a, shape, copy=False) in
# numpy.array_api: https://github.com/numpy/numpy/issues/23410
value = xp.asarray(value, dtype=array.dtype)
if _is_numpy_namespace(xp):
array_np = numpy.asarray(array)
array_np.flat[:: array.shape[0] + 1] += value
return xp.asarray(array_np)
elif value.ndim == 1:
for i in range(array.shape[0]):
array[i, i] += value[i]
def _fill_or_add_to_diagonal(array, value, xp, add_value=True, wrap=False):
"""Implementation to facilitate adding or assigning specified values to the
diagonal of a 2-d array.

If ``add_value`` is `True`, the values are added to the diagonal
elements; otherwise they are assigned to the diagonal elements.
By default, ``add_value`` is set to `True`. This is currently only
supported for 2-d arrays.

The implementation is taken from the `numpy.fill_diagonal` function:
https://github.com/numpy/numpy/blob/v2.0.0/numpy/lib/_index_tricks_impl.py#L799-L929
"""
if array.ndim != 2:
raise ValueError("array should be 2-d")

value = xp.asarray(value, dtype=array.dtype, device=device(array))
end = None
# Explicit, fast formula for the common case. For 2-d arrays, we
# accept rectangular ones.
step = array.shape[1] + 1
if not wrap:
end = array.shape[1] * array.shape[1]

array_flat = xp.reshape(array, (-1,))
if add_value:
array_flat[:end:step] += value
else:
# scalar value
for i in range(array.shape[0]):
array[i, i] += value
array_flat[:end:step] = value
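The strided update on the flattened array can be tried out with plain NumPy. The sketch below is an illustration and not the helper from the diff: `fill_or_add_to_diagonal_sketch` is a hypothetical name, the `wrap` option is left out, and it assumes a C-contiguous input so that `reshape(-1)` returns a view and the write reaches the original array.

```python
import numpy as np


def fill_or_add_to_diagonal_sketch(array, value, add_value=True):
    """Add to or overwrite the main diagonal through a flat strided slice."""
    if array.ndim != 2:
        raise ValueError("array should be 2-d")

    n_cols = array.shape[1]
    step = n_cols + 1      # flat distance between consecutive diagonal cells
    end = n_cols * n_cols  # stops tall arrays from wrapping around

    flat = array.reshape(-1)  # a view for C-contiguous input (assumption)
    if add_value:
        flat[:end:step] += value
    else:
        flat[:end:step] = value
    return array


square = np.arange(9.0).reshape(3, 3)
fill_or_add_to_diagonal_sketch(square, 10.0)  # add a scalar

tall = np.zeros((4, 2))
fill_or_add_to_diagonal_sketch(tall, np.array([1.0, 2.0]), add_value=False)
```

For a row-major array with `n_cols` columns, element `(i, i)` sits at flat index `i * (n_cols + 1)`, which is why a single slice with that step touches exactly the diagonal.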


def _max_precision_float_dtype(xp, device):
@@ -1000,3 +1040,11 @@ def _count_nonzero(X, xp, device, axis=None, sample_weight=None):

zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)


def _modify_in_place_if_numpy(xp, func, *args, out=None, **kwargs):
    if _is_numpy_namespace(xp):
        func(*args, out=out, **kwargs)
    else:
        out = func(*args, **kwargs)
    return out
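The calling pattern of this helper can be sketched as follows. The standalone copy below is a hypothetical stand-in: it replaces the private `_is_numpy_namespace` check with `xp is np`, and `other_xp` is a made-up namespace whose `exp` takes no `out` argument, used here only to exercise the second branch.

```python
import types

import numpy as np


def modify_in_place_if_numpy_sketch(xp, func, *args, out=None, **kwargs):
    """Call ``func`` with ``out=`` for NumPy, otherwise return a new array."""
    if xp is np:  # stand-in for scikit-learn's private namespace check
        func(*args, out=out, **kwargs)
    else:
        out = func(*args, **kwargs)
    return out


# NumPy path: the result is written into the existing buffer.
K = np.array([0.0, -1.0, -2.0])
K_numpy = modify_in_place_if_numpy_sketch(np, np.exp, K, out=K)

# Hypothetical namespace whose ``exp`` has no ``out`` parameter.
other_xp = types.SimpleNamespace(exp=lambda x: np.exp(x))
Z = np.array([0.0, -1.0])
Z_new = modify_in_place_if_numpy_sketch(other_xp, other_xp.exp, Z, out=Z)
```

Callers always rebind the return value (for example `K = ...`), so the same line works whether the operation happened in place or produced a new array.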