ENH Remove hardcoded device choice in _weighted_sum by betatim · Pull Request #27232 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

ENH Remove hardcoded device choice in _weighted_sum #27232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,16 @@ def supported_float_dtypes(xp):
return (xp.float64, xp.float32)


def max_precision_float_dtype(xp, device):
    """Return the highest precision float dtype supported by namespace and device.

    Parameters
    ----------
    xp : module
        Array namespace. Only its ``__name__``, ``float32`` and ``float64``
        attributes are used here.

    device : object, str or None
        Device the array lives on. Either an object exposing a ``type``
        attribute or a device string such as ``"mps"`` or ``"mps:0"``.
        Anything else is treated as a non-"mps" device.

    Returns
    -------
    dtype
        ``xp.float32`` when the namespace is torch (directly or through
        ``array_api_compat``) and the device type is ``"mps"``; ``xp.float64``
        in every other case.
    """
    # temporary hack while waiting for a proper inspection API, see:
    # https://github.com/data-apis/array-api/issues/640
    if xp.__name__ not in {"array_api_compat.torch", "torch"}:
        # Only torch is special-cased; the device is not inspected at all for
        # other namespaces, so any device representation is accepted.
        return xp.float64

    # Device objects carry the backend name in ``.type``; plain strings use the
    # "<type>" or "<type>:<index>" spelling, so keep only the part before ":".
    device_type = getattr(device, "type", None)
    if device_type is None and isinstance(device, str):
        device_type = device.partition(":")[0]

    # NOTE(review): presumably the "mps" backend lacks float64 support, which
    # is why float32 is the ceiling there -- confirm against the torch docs.
    return xp.float32 if device_type == "mps" else xp.float64


class _ArrayAPIWrapper:
"""sklearn specific Array API compatibility wrapper

Expand Down Expand Up @@ -453,7 +463,12 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
# with lazy Array API implementations. See:
# https://github.com/data-apis/array-api/issues/642
if xp is None:
xp, _ = get_namespace(sample_score)
# Make sure the scores and weights belong to the same namespace
if sample_weight is not None:
xp, _ = get_namespace(sample_score, sample_weight)
else:
xp, _ = get_namespace(sample_score)

if normalize and _is_numpy_namespace(xp):
sample_score_np = numpy.asarray(sample_score)
if sample_weight is not None:
Expand All @@ -463,14 +478,14 @@ def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
return float(numpy.average(sample_score_np, weights=sample_weight_np))

if not xp.isdtype(sample_score.dtype, "real floating"):
# We move to cpu device ahead of time since certain devices may not support
# float64, but we want the same precision for all devices and namespaces.
sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
sample_score = xp.astype(
sample_score, max_precision_float_dtype(xp, device(sample_score))
)

if sample_weight is not None:
sample_weight = xp.asarray(sample_weight, dtype=sample_score.dtype)
if not xp.isdtype(sample_weight.dtype, "real floating"):
sample_weight = xp.astype(sample_weight, xp.float64)
sample_weight = xp.asarray(
sample_weight, dtype=sample_score.dtype, device=device(sample_score)
)

if normalize:
if sample_weight is not None:
Expand Down
0