Closed
Labels: type:others (issues not falling in bug, performance, support, build and install or feature)
Description
Using a non-tuple sequence for multidimensional indexing is deprecated in NumPy (numpy/numpy#9686) and has already been removed in Jax (jax-ml/jax#4867), but TensorFlow still relies on this feature. The difference shows up when passing a Jax array instead of a NumPy array:
>>> from tensorflow.keras import Sequential
>>> from jax.numpy import array
>>> Sequential().predict(array([0]))
[...]
TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.
[...]
>>> from numpy import array
>>> Sequential().predict(array([0]))
array([0])
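The Jax error message asks callers to state their intent explicitly instead of indexing with a plain list. Below is a minimal NumPy sketch of the two unambiguous spellings. The array `arr` and the index list `idx` are illustrative only and are not taken from TensorFlow's code. A tuple supplies one index per axis. An array performs integer-array (fancy) indexing along the first axis.

```python
import numpy as np

# Illustrative data: a 2x3 array and a list of indices.
arr = np.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
idx = [0, 1]

# Tuple: one index per axis, so this selects the single element arr[0, 1].
by_tuple = arr[tuple(idx)]

# Array: integer-array indexing along axis 0, so this selects rows 0 and 1.
by_array = arr[np.array(idx)]

print(by_tuple)        # 1
print(by_array.shape)  # (2, 3)
```

Code written in one of these two forms behaves the same on NumPy and Jax arrays, because it does not depend on the deprecated list-indexing behaviour.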
In the meantime, I opened an issue on Jax (jax-ml/jax#8980) asking to revert this change. Regardless of the outcome there, I think TensorFlow should follow the deprecation and stop relying on non-tuple sequence indexing.