8000 FIX dtype handling regression in pairwise distance computation (#29746) · scikit-learn/scikit-learn@be8e28d · GitHub
[go: up one dir, main page]

Skip to content

Commit be8e28d

Browse files
authored
FIX dtype handling regression in pairwise distance computation (#29746)
1 parent 68d8c2c commit be8e28d

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

sklearn/metrics/pairwise.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@
1616
from .. import config_context
1717
from ..exceptions import DataConversionWarning
1818
from ..preprocessing import normalize
19-
from ..utils import (
20-
check_array,
21-
gen_batches,
22-
gen_even_slices,
23-
)
19+
from ..utils import check_array, gen_batches, gen_even_slices
2420
from ..utils._array_api import (
2521
_fill_or_add_to_diagonal,
2622
_find_matching_floating_dtype,
@@ -1169,7 +1165,11 @@ def cosine_distances(X, Y=None):
11691165
# TODO: remove the xp.asarray calls once the following is fixed:
11701166
# https://github.com/data-apis/array-api-compat/issues/177
11711167
device_ = device(S)
1172-
S = xp.clip(S, xp.asarray(0.0, device=device_), xp.asarray(2.0, device=device_))
1168+
S = xp.clip(
1169+
S,
1170+
xp.asarray(0.0, device=device_, dtype=S.dtype),
1171+
xp.asarray(2.0, device=device_, dtype=S.dtype),
1172+
)
11731173
if X is Y or Y is None:
11741174
# Ensure that distances between vectors and themselves are set to 0.0.
11751175
# This may not be the case due to floating point rounding errors.

0 commit comments

Comments
 (0)
0