8000 FIX Fix device detection when array API dispatch is disabled (#30454) · scikit-learn/scikit-learn@24ffd96 · GitHub
[go: up one dir, main page]

Skip to content

Commit 24ffd96

Browse files
lesteveogriselOmarManzoor
authored andcommitted
FIX Fix device detection when array API dispatch is disabled (#30454)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org> Co-authored-by: Omar Salman <omar.salman2007@gmail.com>
1 parent bd76910 commit 24ffd96

File tree

5 files changed

+102
-15
lines changed

5 files changed

+102
-15
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
- Fix regression when scikit-learn metric called on PyTorch CPU tensors would
2+
raise an error (with array API dispatch disabled which is the default).
3+
By :user:`Loïc Estève <lesteve>`

sklearn/metrics/tests/test_common.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1817,6 +1817,40 @@ def check_array_api_metric(
18171817
if isinstance(multioutput, np.ndarray):
18181818
metric_kwargs["multioutput"] = xp.asarray(multioutput, device=device)
18191819

1820+
# When array API dispatch is disabled, and np.asarray works (for example PyTorch
1821+
# with CPU device), calling the metric function with such numpy compatible inputs
1822+
# should work (albeit by implicitly converting to numpy arrays instead of
1823+
# dispatching to the array library).
1824+
try:
1825+
np.asarray(a_xp)
1826+
np.asarray(b_xp)
1827+
numpy_as_array_works = True
1828+
except TypeError:
1829+
# PyTorch with CUDA device and CuPy raise TypeError consistently.
1830+
# Exception type may need to be updated in the future for other
1831+
# libraries.
1832+
numpy_as_array_works = False
1833+
1834+
if numpy_as_array_works:
1835+
metric_xp = metric(a_xp, b_xp, **metric_kwargs)
1836+
assert_allclose(
1837+
metric_xp,
1838+
metric_np,
1839+
atol=_atol_for_type(dtype_name),
1840+
)
1841+
metric_xp_mixed_1 = metric(a_np, b_xp, **metric_kwargs)
1842+
assert_allclose(
1843+
metric_xp_mixed_1,
1844+
metric_np,
1845+
atol=_atol_for_type(dtype_name),
1846+
)
1847+
metric_xp_mixed_2 = metric(a_xp, b_np, **metric_kwargs)
1848+
assert_allclose(
1849+
metric_xp_mixed_2,
1850+
metric_np,
1851+
atol=_atol_for_type(dtype_name),
1852+
)
1853+
18201854
with config_context(array_api_dispatch=True):
18211855
metric_xp = metric(a_xp, b_xp, **metric_kwargs)
18221856

sklearn/utils/_array_api.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,17 @@ def _check_array_api_dispatch(array_api_dispatch):
130130

131131
def _single_array_device(array):
132132
"""Hardware device where the array data resides on."""
133-
if isinstance(array, (numpy.ndarray, numpy.generic)) or not hasattr(
134-
array, "device"
133+
if (
134+
isinstance(array, (numpy.ndarray, numpy.generic))
135+
or not hasattr(array, "device")
136+
# When array API dispatch is disabled, we expect the scikit-learn code
137+
# to use np.asarray so that the resulting NumPy array will implicitly use the
138+
# CPU. In this case, scikit-learn should stay as device neutral as possible,
139+
# hence the use of `device=None` which is accepted by all libraries, before
140+
# and after the expected conversion to NumPy via np.asarray.
141+
or not get_config()["array_api_dispatch"]
135142
):
136-
return "cpu"
143+
return None
137144
else:
138145
return array.device
139146

sklearn/utils/estimator_checks.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,40 @@ def check_array_api_input(
11131113
"transform",
11141114
)
11151115

1116+
try:
1117+
np.asarray(X_xp)
1118+
np.asarray(y_xp)
1119+
# TODO There are a few errors in SearchCV with array-api-strict because
1120+
# we end up doing X[train_indices] where X is an array-api-strict array
1121+
# and train_indices is a numpy array. array-api-strict insists
1122+
# train_indices should be an array-api-strict array. On the other hand,
1123+
# all the array API libraries (PyTorch, jax, CuPy) accept indexing with a
1124+
# numpy array. This is probably not worth doing anything about for
1125+
# now since array-api-strict seems a bit too strict ...
1126+
numpy_asarray_works = xp.__name__ != "array_api_strict"
1127+
1128+
except TypeError:
1129+
# PyTorch with CUDA device and CuPy raise TypeError consistently.
1130+
# Exception type may need to be updated in the future for other
1131+
# libraries.
1132+
numpy_asarray_works = False
1133+
1134+
if numpy_asarray_works:
1135+
# In this case, array_api_dispatch is disabled and we rely on np.asarray
1136+
# being called to convert the non-NumPy inputs to NumPy arrays when needed.
1137+
est_fitted_with_as_array = clone(est).fit(X_xp, y_xp)
1138+
# We only do a smoke test for now, in order to avoid complicating the
1139+
# test function even further.
1140+
for method_name in methods:
1141+
method = getattr(est_fitted_with_as_array, method_name, None)
1142+
if method is None:
1143+
continue
1144+
1145+
if method_name == "score":
1146+
method(X_xp, y_xp)
1147+
else:
1148+
method(X_xp)
1149+
11161150
for method_name in methods:
11171151
method = getattr(est, method_name, None)
11181152
if method is None:

sklearn/utils/tests/test_array_api.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,7 @@ def test_device_none_if_no_input():
248248
assert device(None, "name") is None
249249

250250

251+
@skip_if_array_api_compat_not_configured
251252
def test_device_inspection():
252253
class Device:
253254
def __init__(self, name):
@@ -273,18 +274,26 @@ def __init__(self, device_name):
273274
with pytest.raises(TypeError):
274275
hash(Array("device").device)
275276

276-
# Test raise if on different devices
277+
# If array API dispatch is disabled the device should be ignored. Erroring
278+
# early for different devices would prevent the np.asarray conversion to
279+
# happen. For example, `r2_score(np.ones(5), torch.ones(5))` should work
280+
# fine with array API disabled.
281+
assert device(Array("cpu"), Array("mygpu")) is None
282+
283+
# Test that ValueError is raised if on different devices and array API dispatch is
284+
# enabled.
277285
err_msg = "Input arrays use different devices: cpu, mygpu"
278-
with pytest.raises(ValueError, match=err_msg):
279-
device(Array("cpu"), Array("mygpu"))
286+
with config_context(array_api_dispatch=True):
287+
with pytest.raises(ValueError, match=err_msg):
288+
device(Array("cpu"), Array("mygpu"))
280289

281-
# Test expected value is returned otherwise
282-
array1 = Array("device")
283-
array2 = Array("device")
290+
# Test expected value is returned otherwise
291+
array1 = Array("device")
292+
array2 = Array("device")
284293

285-
assert array1.device == device(array1)
286-
assert array1.device == device(array1, array2)
287-
assert array1.device == device(array1, array1, array2)
294+
assert array1.device == device(array1)
295+
assert array1.device == device(array1, array2)
296+
assert array1.device == device(array1, array1, array2)
288297

289298

290299
# TODO: add cupy to the list of libraries once the the following upstream issue
@@ -553,7 +562,7 @@ def test_get_namespace_and_device():
553562
namespace, is_array_api, device = get_namespace_and_device(some_torch_tensor)
554563
assert namespace is get_namespace(some_numpy_array)[0]
555564
assert not is_array_api
556-
assert device.type == "cpu"
565+
assert device is None
557566

558567
# Otherwise, expose the torch namespace and device via array API compat
559568
# wrapper.
@@ -621,8 +630,8 @@ def test_sparse_device(csr_container, dispatch):
621630
try:
622631
with config_context(array_api_dispatch=dispatch):
623632
assert device(a, b) is None
624-
assert device(a, numpy.array([1])) == "cpu"
633+
assert device(a, numpy.array([1])) is None
625634
assert get_namespace_and_device(a, b)[2] is None
626-
assert get_namespace_and_device(a, numpy.array([1]))[2] == "cpu"
635+
assert get_namespace_and_device(a, numpy.array([1]))[2] is None
627636
except ImportError:
628637
raise SkipTest("array_api_compat is not installed")

0 commit comments

Comments
 (0)
0