8000 Merge pull request #24242 from charris/backport-24187 · numpy/numpy@0409b76 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0409b76

Browse files
authored
Merge pull request #24242 from charris/backport-24187
BUG: Fix the signature for np.array_api.take
2 parents 98b6d3c + 2fff424 commit 0409b76

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

numpy/array_api/_indexing_functions.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
import numpy as np
77

8-
def take(x: Array, indices: Array, /, *, axis: int) -> Array:
8+
def take(x: Array, indices: Array, /, *, axis: Optional[int] = None) -> Array:
99
"""
1010
Array API compatible wrapper for :py:func:`np.take <numpy.take>`.
1111
1212
See its docstring for more information.
13-
"""
13+
"""
14+
if axis is None and x.ndim != 1:
15+
raise ValueError("axis must be specified when ndim > 1")
1416
if indices.dtype not in _integer_dtypes:
1517
raise TypeError("Only integer dtypes are allowed in indexing")
16-
if indices.ndim != 1:
18+
if indices.ndim != 1:
1719
raise ValueError("Only 1-dim indices array is supported")
1820
return Array._new(np.take(x._array, indices._array, axis=axis))

0 commit comments

Comments
 (0)
0