8000 Factor out max float precision determination · scikit-learn/scikit-learn@0576ec2 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0576ec2

Browse files
committed
Factor out max float precision determination
1 parent 913faca commit 0576ec2

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

sklearn/utils/_array_api.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,16 @@ def supported_float_dtypes(xp):
182182
return (xp.float64, xp.float32)
183183

184184

185+
def max_precision_float_dtype(xp, device):
186+
"""Highest precision float dtype support by namespace and device"""
187+
# temporary hack while waiting for a proper inspection API, see:
188+
# https://github.com/data-apis/array-api/issues/640
189+
if xp.__name__ in {"array_api_compat.torch", "torch"} and device.type == "mps":
190+
return xp.float32
191+
else:
192+
return xp.float64
193+
194+
185195
class _ArrayAPIWrapper:
186196
"""sklearn specific Array API compatibility wrapper
187197
@@ -468,17 +478,14 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
468478
return float(numpy.average(sample_score_np, weights=sample_weight_np))
469479

470480
if not xp.isdtype(sample_score.dtype, "real floating"):
471-
# The MPS device does not support float64
472-
if (
473-
xp.__name__ in {"array_api_compat.torch", "torch"}
474-
and device(sample_score).type == "mps"
475-
):
476-
sample_score = xp.astype(sample_score, xp.float32)
477-
else:
478-
sample_score = xp.astype(sample_score, xp.float64)
481+
sample_score = xp.astype(
482+
sample_score, max_precision_float_dtype(xp, device(sample_score))
483+
)
479484

480485
if sample_weight is not None:
481-
sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
486+
sample_weight = xp.asarray(
487+
sample_weight, dtype=sample_score.dtype, device=device(sample_score)
488+
)
482489

483490
if normalize:
484491
if sample_weight is not None:

0 commit comments

Comments
 (0)
0