diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 7a56387f61ea8..c3f0155408cf6 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -16,11 +16,7 @@ from .. import config_context from ..exceptions import DataConversionWarning from ..preprocessing import normalize -from ..utils import ( - check_array, - gen_batches, - gen_even_slices, -) +from ..utils import check_array, gen_batches, gen_even_slices from ..utils._array_api import ( _fill_or_add_to_diagonal, _find_matching_floating_dtype, @@ -1169,7 +1165,11 @@ def cosine_distances(X, Y=None): # TODO: remove the xp.asarray calls once the following is fixed: # https://github.com/data-apis/array-api-compat/issues/177 device_ = device(S) - S = xp.clip(S, xp.asarray(0.0, device=device_), xp.asarray(2.0, device=device_)) + S = xp.clip( + S, + xp.asarray(0.0, device=device_, dtype=S.dtype), + xp.asarray(2.0, device=device_, dtype=S.dtype), + ) if X is Y or Y is None: # Ensure that distances between vectors and themselves are set to 0.0. # This may not be the case due to floating point rounding errors.