10000 FIX always return None as device when array API dispatch is disabled … · jeremiedbb/scikit-learn@ea121c3 · GitHub
[go: up one dir, main page]

Skip to content

Commit ea121c3

Browse files
ogriseljeremiedbb
authored andcommitted
FIX always return None as device when array API dispatch is disabled (scikit-learn#29119)
1 parent 62124b6 commit ea121c3

File tree

3 files changed

+44
-5
lines changed

3 files changed

+44
-5
lines changed

doc/whats_new/v1.5.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,20 @@ Version 1.5.1
2323
Changelog
2424
---------
2525

26+
:mod:`sklearn.metrics`
27+
......................
28+
29+
- |Fix| Fix a regression in :func:`metrics.r2_score`. Passing torch CPU tensors
30+
with array API dispatched disabled would complain about non-CPU devices
31+
instead of implicitly converting those inputs as regular NumPy arrays.
32+
:pr:`29119` by :user:`Olivier Grisel`.
33+
2634
:mod:`sklearn.model_selection`
2735
..............................
2836

2937
- |Fix| Fix a regression in :class:`model_selection.GridSearchCV` for parameter
3038
grids that have heterogeneous parameter values.
31-
:pr:`29078` by :user:`Loïc Estève <lesteve>`
39+
:pr:`29078` by :user:`Loïc Estève <lesteve>`.
3240

3341

3442
.. _changes_1_5:

sklearn/utils/_array_api.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,15 @@ def get_namespace_and_device(*array_list, remove_none=True, remove_types=(str,))
560560

561561
skip_remove_kwargs = dict(remove_none=False, remove_types=[])
562562

563-
return (
564-
*get_namespace(*array_list, **skip_remove_kwargs),
565-
device(*array_list, **skip_remove_kwargs),
566-
)
563+
xp, is_array_api = get_namespace(*array_list, **skip_remove_kwargs)
564+
if is_array_api:
565+
return (
566+
xp,
567+
is_array_api,
568+
device(*array_list, **skip_remove_kwargs),
569+
)
570+
else:
571+
return xp, False, None
567572

568573

569574
def _expit(X, xp=None):

sklearn/utils/tests/test_array_api.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_ravel,
2222
device,
2323
get_namespace,
24+
get_namespace_and_device,
2425
indexing_dtype,
2526
supported_float_dtypes,
2627
yield_namespace_device_dtype_combinations,
@@ -504,3 +505,28 @@ def test_indexing_dtype(namespace, _device, _dtype):
504505
assert indexing_dtype(xp) == xp.int32
505506
else:
506507
assert indexing_dtype(xp) == xp.int64
508+
509+
510+
def test_get_namespace_and_device():
511+
# Use torch as a library with custom Device objects:
512+
torch = pytest.importorskip("torch")
513+
xp_torch = pytest.importorskip("array_api_compat.torch")
514+
some_torch_tensor = torch.arange(3, device="cpu")
515+
some_numpy_array = numpy.arange(3)
516+
517+
# When dispatch is disabled, get_namespace_and_device should return the
518+
# default NumPy wrapper namespace and no device. Our code will handle such
519+
# inputs via the usual __array__ interface without attempting to dispatch
520+
# via the array API.
521+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
522+
assert namespace is get_namespace(some_numpy_array)[0]
523+
assert not is_array_api
524+
assert device is None
525+
526+
# Otherwise, expose the torch namespace and device via array API compat
527+
# wrapper.
528+
with config_context(array_api_dispatch=True):
529+
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
530+
assert namespace is xp_torch
531+
assert is_array_api
532+
assert device == some_torch_tensor.device

0 commit comments

Comments
 (0)
0