@@ -182,6 +182,16 @@ def supported_float_dtypes(xp):
182
182
return (xp .float64 , xp .float32 )
183
183
184
184
185
+ def max_precision_float_dtype (xp , device ):
186
+ """Highest precision float dtype support by namespace and device"""
187
+ # temporary hack while waiting for a proper inspection API, see:
188
+ # https://github.com/data-apis/array-api/issues/640
189
+ if xp .__name__ in {"array_api_compat.torch" , "torch" } and device .type == "mps" :
190
+ return xp .float32
191
+ else :
192
+ return xp .float64
193
+
194
+
185
195
class _ArrayAPIWrapper :
186
196
"""sklearn specific Array API compatibility wrapper
187
197
@@ -468,17 +478,14 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
468
478
return float (numpy .average (sample_score_np , weights = sample_weight_np ))
469
479
470
480
if not xp .isdtype (sample_score .dtype , "real floating" ):
471
- # The MPS device does not support float64
472
- if (
473
- xp .__name__ in {"array_api_compat.torch" , "torch" }
474
- and device (sample_score ).type == "mps"
475
- ):
476
- sample_score = xp .astype (sample_score , xp .float32 )
477
- else :
478
- sample_score = xp .astype (sample_score , xp .float64 )
481
+ sample_score = xp .astype (
482
+ sample_score , max_precision_float_dtype (xp , device (sample_score ))
483
+ )
479
484
480
485
if sample_weight is not None :
481
- sample_weight = xp .asarray (sample_weight , dtype = sample_score .dtype )
486
+ sample_weight = xp .asarray (
487
+ sample_weight , dtype = sample_score .dtype , device = device (sample_score )
488
+ )
482
489
483
490
if normalize :
484
491
if sample_weight is not None :
0 commit comments