8000 fix: `mps` device support in `entropy` (#29321) · scikit-learn/scikit-learn@a408a59 · GitHub
[go: up one dir, main page]

Skip to content

Commit a408a59

Browse files
authored
fix: mps device support in entropy (#29321)
1 parent 0dcc364 commit a408a59

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

sklearn/metrics/cluster/_supervised.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import numpy as np
1616
from scipy import sparse as sp
1717

18-
from ...utils._array_api import get_namespace
18+
from ...utils._array_api import _max_precision_float_dtype, get_namespace_and_device
1919
from ...utils._param_validation import Interval, StrOptions, validate_params
2020
from ...utils.multiclass import type_of_target
2121
from ...utils.validation import check_array, check_consistent_length
@@ -1275,12 +1275,12 @@ def entropy(labels):
12751275
-----
12761276
The logarithm used is the natural logarithm (base-e).
12771277
"""
1278-
xp, is_array_api_compliant = get_namespace(labels)
1278+
xp, is_array_api_compliant, device_ = get_namespace_and_device(labels)
12791279
labels_len = labels.shape[0] if is_array_api_compliant else len(labels)
12801280
if labels_len == 0:
12811281
return 1.0
12821282

1283-
pi = xp.astype(xp.unique_counts(labels)[1], xp.float64)
1283+
pi = xp.astype(xp.unique_counts(labels)[1], _max_precision_float_dtype(xp, device_))
12841284

12851285
# single cluster => zero entropy
12861286
if pi.size == 1:

sklearn/utils/_array_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,18 @@ def _add_to_diagonal(array, value, xp):
609609
array[i, i] += value
610610

611611

612+
def _max_precision_float_dtype(xp, device):
613+
"""Return the float dtype with the highest precision supported by the device."""
614+
# TODO: Update to use `__array_namespace__info__()` from array-api v2023.12
615+
# when/if that becomes more widespread.
616+
xp_name = xp.__name__
617+
if xp_name in {"array_api_compat.torch", "torch"} and (
618+
str(device).startswith("mps")
619+
): # pragma: no cover
620+
return xp.float32
621+
return xp.float64
622+
623+
612624
def _find_matching_floating_dtype(*arrays, xp):
613625
"""Find a suitable floating point dtype when computing with arrays.
614626

sklearn/utils/tests/test_array_api.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
_estimator_with_converted_arrays,
1818
_is_numpy_namespace,
1919
_isin,
20+
_max_precision_float_dtype,
2021
_nanmax,
2122
_nanmin,
2223
_NumPyAPIWrapper,
@@ -510,6 +511,15 @@ def test_indexing_dtype(namespace, _device, _dtype):
510511
assert indexing_dtype(xp) == xp.int64
511512

512513

514+
@pytest.mark.parametrize(
515+
"namespace, _device, _dtype", yield_namespace_device_dtype_combinations()
516+
)
517+
def test_max_precision_float_dtype(namespace, _device, _dtype):
518+
xp = _array_api_for_tests(namespace, _device)
519+
expected_dtype = xp.float32 if _device == "mps" else xp.float64
520+
assert _max_precision_float_dtype(xp, _device) == expected_dtype
521+
522+
513523
@pytest.mark.parametrize(
514524
"array_namespace, device, _", yield_namespace_device_dtype_combinations()
515525
)

0 commit comments

Comments
 (0)
0