Deprecate non-tuple nd-indices #53458
@donno2048

Indexing a multidimensional array with a non-tuple sequence is deprecated in NumPy (numpy/numpy#9686) and has already been removed in JAX (jax-ml/jax#4867), but TensorFlow still relies on it. This can be seen by passing a JAX array and then a NumPy array to the same call:

>>> from tensorflow.keras import Sequential
>>> from jax.numpy import array
>>> Sequential().predict(array([0]))
[...]
TypeError: Using a non-tuple sequence for multidimensional indexing is not allowed; use `arr[array(seq)]` instead of `arr[seq]`. See https://github.com/google/jax/issues/4564 for more information.
[...]
>>> from numpy import array
>>> Sequential().predict(array([0]))
array([0])

For now, I have opened an issue on JAX (jax-ml/jax#8980) asking to revert this change, but regardless of the outcome there, I think TensorFlow should accept the deprecation.
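
For readers unfamiliar with the deprecation, here is a minimal sketch of the difference between the two index forms, using plain NumPy. The array `arr` and the list `seq` are made up for illustration and are not taken from TensorFlow's code; the sketch only assumes that the offending index is a list holding one slice per axis.

```python
import numpy as np

# Illustrative data only: a 3x4 array and a list with one slice per axis.
arr = np.arange(12).reshape(3, 4)
seq = [slice(0, 2), slice(1, 3)]

# Deprecated form (deliberately not executed here): arr[seq]
# NumPy deprecated it (numpy/numpy#9686) and JAX rejects it with a TypeError.

# Tuple form: converting the sequence to a tuple makes the
# "one index per axis" meaning explicit.
result = arr[tuple(seq)]

# The tuple form is equivalent to ordinary multi-axis slicing.
print(np.array_equal(result, arr[0:2, 1:3]))
```

Under that assumption, wrapping the sequence in `tuple(...)` before indexing is the kind of change that accepting the deprecation would involve.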

Labels: type:others (issues not falling in bug, performance, support, build and install or feature)