8000 MAINT: Add default value for `axis` parameter in `numpy.take_along_axis` · numpy/numpy@ff92ba8 · GitHub
[go: up one dir, main page]

Skip to content

Commit ff92ba8

Browse files
committed
MAINT: Add default value for axis parameter in numpy.take_along_axis
1 parent 78d3a5b commit ff92ba8

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
* NumPy's ``__array_api_version__`` was upgraded from ``2023.12`` to ``2024.12``.
22
* `numpy.count_nonzero` for ``axis=None`` (default) now returns a NumPy scalar
33
instead of a Python integer.
4+
* The parameter ``axis`` in `numpy.take_along_axis` function has now a default
5+
value of ``-1``.

numpy/lib/_shape_base_impl.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@ def _make_along_axis_idx(arr_shape, indices, axis):
5050
return tuple(fancy_index)
5151

5252

53-
def _take_along_axis_dispatcher(arr, indices, axis):
53+
def _take_along_axis_dispatcher(arr, indices, axis=None):
5454
return (arr, indices)
5555

5656

5757
@array_function_dispatch(_take_along_axis_dispatcher)
58-
def take_along_axis(arr, indices, axis):
58+
def take_along_axis(arr, indices, axis=-1):
5959
"""
6060
Take values from the input array by matching 1d index and data slices.
6161
@@ -71,14 +71,17 @@ def take_along_axis(arr, indices, axis):
7171
arr : ndarray (Ni..., M, Nk...)
7272
Source array
7373
indices : ndarray (Ni..., J, Nk...)
74-
Indices to take along each 1d slice of `arr`. This must match the
75-
dimension of arr, but dimensions Ni and Nj only need to broadcast
76-
against `arr`.
77-
axis : int
74+
Indices to take along each 1d slice of ``arr``. This must match the
75+
dimension of ``arr``, but dimensions Ni and Nj only need to broadcast
76+
against ``arr``.
77+
axis : int or None, optional
7878
The axis to take 1d slices along. If axis is None, the input array is
7979
treated as if it had first been flattened to 1d, for consistency with
8080
`sort` and `argsort`.
8181
82+
.. versionchanged:: 2.3
83+
The default value is now ``-1``.
84+
8285
Returns
8386
-------
8487
out: ndarray (Ni..., J, Nk...)

numpy/lib/_shape_base_impl.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class _SupportsArrayWrap(Protocol):
7070
def take_along_axis(
7171
arr: _ScalarT | NDArray[_ScalarT],
7272
indices: NDArray[integer],
73-
axis: int | None,
73+
axis: int | None = ...,
7474
) -> NDArray[_ScalarT]: ...
7575

7676
def put_along_axis(

0 commit comments

Comments
 (0)
0