@@ -453,7 +453,12 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
453
453
# with lazy Array API implementations. See:
454
454
# https://github.com/data-apis/array-api/issues/642
455
455
if xp is None :
456
- xp , _ = get_namespace (sample_score )
456
+ # Make sure the scores and weights belong to the same namespace
457
+ if sample_weight is not None :
458
+ xp , _ = get_namespace (sample_score , sample_weight )
459
+ else :
460
+ xp , _ = get_namespace (sample_score )
461
+
457
462
if normalize and _is_numpy_namespace (xp ):
458
463
sample_score_np = numpy .asarray (sample_score )
459
464
if sample_weight is not None :
@@ -463,14 +468,17 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
463
468
return float (numpy .average (sample_score_np , weights = sample_weight_np ))
464
469
465
470
if not xp .isdtype (sample_score .dtype , "real floating" ):
466
- # We move to cpu device ahead of time since certain devices may not support
467
- # float64, but we want the same precision for all devices and namespaces.
468
- sample_score = xp .astype (xp .asarray (sample_score , device = "cpu" ), xp .float64 )
471
+ # The MPS device does not support float64
472
+ if (
473
+ xp .__name__ in {"array_api_compat.torch" , "torch" }
474
+ and device (sample_score ).type == "mps"
475
+ ):
476
+ sample_score = xp .astype (sample_score , xp .float32 )
477
+ else :
478
+ sample_score = xp .astype (sample_score , xp .float64 )
469
479
470
480
if sample_weight is not None :
471
481
sample_weight = xp .asarray (sample_weight , dtype = sample_score .dtype )
472
- if not xp .isdtype (sample_weight .dtype , "real floating" ):
473
- sample_weight = xp .astype (sample_weight , xp .float64 )
474
482
475
483
if normalize :
476
484
if sample_weight is not None :
0 commit comments