8000 FIX always return None as device when array API dispatch is disabled … · scikit-learn-bot/scikit-learn@78102fd · GitHub
[go: up one dir, main page]

Skip to content

Commit 78102fd

Browse files
authored
FIX always return None as device when array API dispatch is disabled (scikit-learn#29119)
1 parent 8798dfe commit 78102fd

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

doc/whats_new/v1.5.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ Version 1.5.1
2323
Changelog
2424
---------
2525

26+
:mod:`sklearn.metrics`
27+
......................
28+
29+
- |Fix| Fix a regression in :func:`metrics.r2_score`. Passing torch CPU tensors
30+
with array API dispatched disabled would complain about non-CPU devices
31+
instead of implicitly converting those inputs as regular NumPy arrays.
32+
:pr:`29119` by :user:`Olivier Grisel`.
33+
2634
:mod:`sklearn.model_selection`
2735
..............................
2836

2937
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
3038
grids that have heterogeneous parameter values.
31-
:pr:`29078` by :user:`Loïc Estève <lesteve>`
39+
:pr:`29078` by :user:`Loïc Estève <lesteve>`.
3240

3341

3442
.. _changes_1_5:

sklearn/utils/_array_api.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,10 +568,15 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,))
568568

569569
skip_remove_kwargs = dict(remove_none=False, remove_types=[])
570570

571-
return (
572-
*get_namespace(*array_list, **skip_remove_kwargs),
573-
device(*array_list, **skip_remove_kwargs),
574-
)
571+
xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
572+
if is_array_api:
573+
return (
574+
xp,
575+
is_array_api,
576+
device(*array_list, **skip_remove_kwargs),
577+
)
578+
else:
579+
return xp, False, None
575580

576581

577582
def _expit(X, xp=None):

sklearn/utils/tests/test_array_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
_ravel,
2323
device,
2424
get_namespace,
25+
get_namespace_and_device,
2526
indexing_dtype,
2627
supported_float_dtypes,
2728
yield_namespace_device_dtype_combinations,
@@ -540,3 +541,28 @@ def test_isin(
540541
)
541542

542543
assert_array_equal(_convert_to_numpy(result, xp=xp), expected)
544+
545+
546+
def test_get_namespace_and_device():
547+
# Use torch as a library with custom Device objects:
548+
torch = pytest.importorskip("torch")
549+
xp_torch = pytest.importorskip("array_api_compat.torch")
550+
some_torch_tensor = torch.arange(3, device="cpu")
551+
some_numpy_array = numpy.arange(3)
552+
553+
# When dispatch is disabled, get_namespace_and_device should return the
554+
# default NumPy wrapper namespace and no device. Our code will handle such
555+
# inputs via the usual __array__ interface without attempting to dispatch
556+
# via the array API.
557+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
558+
assert namespace is get_namespace(some_numpy_array)[0]
559+
assert not is_array_api
560+
assert device is None
561+
562+
# Otherwise, expose the torch namespace and device via array API compat
563+
# wrapper.
564+
with config_context(array_api_dispatch=True):
565+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
566+
assert namespace is xp_torch
567+
assert is_array_api
568+
assert device == some_torch_tensor.device

0 commit comments

Comments
 (0)
0