-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
Enable array API testing for jax.experimental.array_api #29647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I forgot to add jax to one of the CI builds with |
We get 41 failed array API tests. Most (all?) of them seem caused by the inplace assignment problem. Also note that 140 array API tests pass with jax though. EDIT: here is an archive of the 42 failures observed on the CPU-only run (vs 41 for the CUDA run):
I checked and the two The only other failure is an assertion error in
>>> a = xp.arange(2 ** 32 + 1, dtype=xp.uint8)
>>> xp.asarray(2 ** 32 + 1).dtype
Traceback (most recent call last):
Cell In[18], line 1
xp.asarray(2 ** 32 + 1).dtype
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3568 in asarray
return array(a, dtype=dtype, copy=bool(copy), order=order, device=device)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:3387 in array
_ = dtypes.coerce_to_array(object, dtype)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/dtypes.py:322 in coerce_to_array
dtype = _scalar_type_to_dtype(type(x), x)
File ~/miniforge3/envs/dev/lib/python3.11/site-packages/jax/_src/dtypes.py:311 in _scalar_type_to_dtype
raise OverflowError(f"Python int {value} too large to convert to {dtype}")
OverflowError: Python int 4294967297 too large to convert to int32
>>> xp.asarray([2 ** 32 + 1])
Array([1], dtype=int32) |
some progress at scipy/scipy#22070 |
For the record, #30340 was recently merged to EDIT: updating this draft PR to explore usage of |
This is a draft PR to run the existing array API tests on jax inputs. This is not expected to work as jax does not support inplace updates via
__setitem__
in particular as discussed in:.at
for simulating in-place ops data-apis/array-api#609Note that scipy started to run its array API tests against jax but maintains a list of tests to skip because of that design decision:
TODO before considering a review for merge