@@ -85,15 +85,23 @@ def yield_namespace_device_dtype_combinations(include_numpy_namespaces=True):
85
85
yield array_namespace , device , dtype
86
86
yield array_namespace , "mps" , "float32"
87
87
elif array_namespace == "jax.experimental.array_api" :
88
- import jax
89
-
90
- for device in jax .devices ():
91
- # XXX: this will dynamically and implicitly pick-up any
92
- # non-default device if JAX is configured to use it, contrary
93
- # to PyTorch for which we explicitly list all the devices we
94
- # want to test against and then later skip in in the tests if
95
- # it is not available.
96
- yield array_namespace , device , "float32"
88
+ # XXX: this will dynamically and implicitly pick-up any non-default
89
+ # device if JAX is configured to use it, contrary to PyTorch for
90
+ # which we explicitly list all the devices we want to test against
91
+ # and then later skip in in the tests if it is not available.
92
+ #
93
+ # TODO: make yield_namespace_device_dtype_combinations return
94
+ # device names instead and let _array_api_for_tests return the
95
+ # actual device objects and skip the tests if the device is not
96
+ # available.
97
+ try :
98
+ import jax
99
+
100
+ for device in jax .devices ():
101
+ yield array_namespace , device , "float32"
102
+ except ImportError :
103
+ continue
104
+
97
105
else :
98
106
yield array_namespace , None , None
99
107
0 commit comments