8000 Remove hardcoded device choice in _weighted_sum · scikit-learn/scikit-learn@913faca · GitHub
[go: up one dir, main page]

Skip to content

Commit 913faca

Browse files
committed
Remove hardcoded device choice in _weighted_sum
Some Array API compatible libraries do not have a device called 'cpu'. Instead we try and detect the lib+device combination that does not support float64.
1 parent a31e108 commit 913faca

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

sklearn/utils/_array_api.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,12 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
453453
# with lazy Array API implementations. See:
454454
# https://github.com/data-apis/array-api/issues/642
455455
if xp is None:
456-
xp, _ = get_namespace(sample_score)
456+
# Make sure the scores and weights belong to the same namespace
457+
if sample_weight is not None:
458+
xp, _ = get_namespace(sample_score, sample_weight)
459+
else:
460+
xp, _ = get_namespace(sample_score)
461+
457462
if normalize and _is_numpy_namespace(xp):
458463
sample_score_np = numpy.asarray(sample_score)
459464
if sample_weight is not None:
@@ -463,14 +468,17 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
463468
return float(numpy.average(sample_score_np, weights=sample_weight_np))
464469

465470
if not xp.isdtype(sample_score.dtype, "real floating"):
466-
# We move to cpu device ahead of time since certain devices may not support
467-
# float64, but we want the same precision for all devices and namespaces.
468-
sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
471+
# The MPS device does not support float64
472+
if (
473+
xp.__name__ in {"array_api_compat.torch", "torch"}
474+
and device(sample_score).type == "mps"
475+
):
476+
sample_score = xp.astype(sample_score, xp.float32)
477+
else:
478+
sample_score = xp.astype(sample_score, xp.float64)
469479

470480
if sample_weight is not None:
471481
sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
472-
if not xp.isdtype(sample_weight.dtype, "real floating"):
473-
sample_weight = xp.astype(sample_weight, xp.float64)
474482

475483
if normalize:
476484
if sample_weight is not None:

0 commit comments

Comments
 (0)
0