From 913facaf7b2603688e552ad90880dc8fbfde771e Mon Sep 17 00:00:00 2001 From: Tim Head Date: Wed, 30 Aug 2023 13:43:37 +0200 Subject: [PATCH 1/2] Remove hardcoded device choice in _weighted_sum Some Array API compatible libraries do not have a device called 'cpu'. Instead we try and detect the lib+device combination that does not support float64. --- sklearn/utils/_array_api.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index 24534faa931e8..c52ada88ef713 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -453,7 +453,12 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): # with lazy Array API implementations. See: # https://github.com/data-apis/array-api/issues/642 if xp is None: - xp, _ = get_namespace(sample_score) + # Make sure the scores and weights belong to the same namespace + if sample_weight is not None: + xp, _ = get_namespace(sample_score, sample_weight) + else: + xp, _ = get_namespace(sample_score) + if normalize and _is_numpy_namespace(xp): sample_score_np = numpy.asarray(sample_score) if sample_weight is not None: @@ -463,14 +468,17 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): return float(numpy.average(sample_score_np, weights=sample_weight_np)) if not xp.isdtype(sample_score.dtype, "real floating"): - # We move to cpu device ahead of time since certain devices may not support - # float64, but we want the same precision for all devices and namespaces. - sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64) + # The MPS device does not support float64 + if ( + xp.__name__ in {"array_api_compat.torch", "torch"} + and device(sample_score).type == "mps" + ): + sample_score = xp.astype(sample_score, xp.float32) + else: + sample_score = xp.astype(sample_score, xp.float64) if sample_weight is not None: sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype) - if not xp.isdtype(sample_weight.dtype, "real floating"): - sample_weight = xp.astype(sample_weight, xp.float64) if normalize: if sample_weight is not None: From 0576ec28e011774010f26b94167341727b0eabd8 Mon Sep 17 00:00:00 2001 From: Tim Head Date: Thu, 7 Sep 2023 17:40:38 +0200 Subject: [PATCH 2/2] Factor out max float precision determination --- sklearn/utils/_array_api.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/sklearn/utils/_array_api.py b/sklearn/utils/_array_api.py index c52ada88ef713..b3a5ca0ebed8f 100644 --- a/sklearn/utils/_array_api.py +++ b/sklearn/utils/_array_api.py @@ -182,6 +182,16 @@ def supported_float_dtypes(xp): return (xp.float64, xp.float32) +def max_precision_float_dtype(xp, device): + """Highest precision float dtype support by namespace and device""" + # temporary hack while waiting for a proper inspection API, see: + # https://github.com/data-apis/array-api/issues/640 + if xp.__name__ in {"array_api_compat.torch", "torch"} and device.type == "mps": + return xp.float32 + else: + return xp.float64 + + class _ArrayAPIWrapper: """sklearn specific Array API compatibility wrapper @@ -468,17 +478,14 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None): return float(numpy.average(sample_score_np, weights=sample_weight_np)) if not xp.isdtype(sample_score.dtype, "real floating"): - # The MPS device does not support float64 - if ( - xp.__name__ in {"array_api_compat.torch", "torch"} - and device(sample_score).type == "mps" - ): - sample_score = xp.astype(sample_score, xp.float32) - else: - sample_score = xp.astype(sample_score, xp.float64) + sample_score = xp.astype( + sample_score, max_precision_float_dtype(xp, device(sample_score)) + ) if sample_weight is not None: - sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype) + sample_weight = xp.asarray( + sample_weight, dtype=sample_score.dtype, device=device(sample_score) + ) if normalize: if sample_weight is not None: