8000 fix yield_namespace_device_dtype_combinations · ogrisel/scikit-learn@ed1b088 · GitHub
[go: up one dir, main page]

Skip to content

Commit ed1b088

Browse files
committed
fix yield_namespace_device_dtype_combinations
1 parent 77e0642 commit ed1b088

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

sklearn/utils/_array_api.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,15 +85,23 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
8585
yield array_namespace, device, dtype
8686
yield array_namespace, "mps", "float32"
8787
elif array_namespace == "jax.experimental.array_api":
88-
import jax
89-
90-
for device in jax.devices():
91-
# XXX: this will dynamically and implicitly pick-up any
92-
# non-default device if JAX is configured to use it, contrary
93-
# to PyTorch for which we explicitly list all the devices we
94-
# want to test against and then later skip in in the tests if
95-
# it is not available.
96-
yield array_namespace, device, "float32"
88+
# XXX: this will dynamically and implicitly pick-up any non-default
89+
# device if JAX is configured to use it, contrary to PyTorch for
90+
# which we explicitly list all the devices we want to test against
91+
# and then later skip in in the tests if it is not available.
92+
#
93+
# TODO: make yield_namespace_device_dtype_combinations return
94+
# device names instead and let _array_api_for_tests return the
95+
# actual device objects and skip the tests if the device is not
96+
# available.
97+
try:
98+
import jax
99+
100+
for device in jax.devices():
101+
yield array_namespace, device, "float32"
102+
except ImportError:
103+
continue
104+
97105
else:
98106
yield array_namespace, None, None
99107

0 commit comments

Comments
 (0)
0