Description
Describe the issue:
I'm not sure whether this is, strictly speaking, a bug. I am working on a project where we are trying to adopt the Array API standard. I ran into a bug that confused me for a while (glass-dev/glass#997), and when stepping through the code I realised that numpy.apply_along_axis was converting my array-api-strict array into a NumPy array.
Reproduce the code example:
# Example taken from the docs
import numpy as np
import jax.numpy as jnp

def my_func(a):
    """Average first and last element of a 1-D array"""
    return (a[0] + a[-1]) * 0.5

b = np.array([[1,2,3], [4,5,6], [7,8,9]])
c = jnp.array([[1,2,3], [4,5,6], [7,8,9]])

# changes type: the JAX input comes back as a NumPy array
assert type(np.apply_along_axis(my_func, 0, b)) == np.ndarray
assert type(np.apply_along_axis(my_func, 0, c)) == np.ndarray

# preserves type
assert type(np.sum(b, axis=0)) == np.ndarray  # numpy.ndarray
assert type(np.sum(c, axis=0)) != np.ndarray  # jaxlib._jax.ArrayImpl
Error message:
Python and NumPy Versions:
2.4.0
3.14.0 (main, Oct 10 2025, 12:54:13) [Clang 20.1.4 ]
Runtime Environment:
[{'numpy_version': '2.4.0',
'python': '3.14.0 (main, Oct 10 2025, 12:54:13) [Clang 20.1.4 ]',
'uname': uname_result(system='Darwin', node='MacBookPro.mynet', release='25.2.0', version='Darwin Kernel Version 25.2.0: Tue Nov 18 21:09:55 PST 2025; root:xnu-12377.61.12~1/RELEASE_ARM64_T8103', machine='arm64')},
{'simd_extensions': {'baseline': ['NEON', 'NEON_FP16', 'NEON_VFPV4', 'ASIMD'],
'found': ['ASIMDHP', 'ASIMDDP'],
'not_found': ['ASIMDFHM']}},
{'ignore_floating_point_errors_in_matmul': True}]
How does this issue affect you or how did you find it:
The bug caused some confusion, since the array type was being changed outside our control. The confusion was compounded because we were using functools.partial, which hid the problem even further. We ended up "solving" it by changing our function to accept the function and its arguments as separate inputs.