There have been two PRs recently (#8703, #8678) that needed this function, and in general it's needed for any kind of vectorized function returning indices along an axis (`argsort`, `argpartition`, `argmin`, `argmax`, `count`, ...).
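As a concrete illustration (the array values are arbitrary), here is what getting one value per row out of `argmax` takes in 2-D today. Indexing with the index array alone selects along the wrong axis, so an explicit `arange` for the other axis has to be paired with it by hand:

```python
import numpy as np

a = np.array([[3, 9, 1],
              [7, 2, 8]])
idx = a.argmax(axis=1)   # column of each row's maximum: [1, 2]

# a[idx] would index *rows* 1 and 2 (and raise IndexError here, since a has
# only two rows), so the row index has to be supplied explicitly:
rows = np.arange(a.shape[0])
maxima = a[rows, idx]    # one element per row: [9, 8]

assert np.array_equal(maxima, a.max(axis=1))
```

In N dimensions, every non-`axis` dimension needs its own suitably reshaped `arange`, which is exactly the bookkeeping the proposed function hides.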
A Python implementation looks something like this:
```python
import numpy as np

def take_along_axis(arr, ind, axis):
    """
    ... here means a "pack" of dimensions, possibly empty

    arr: array_like of shape (A..., M, B...)
        source array
    ind: array_like of shape (A..., K..., B...)
        indices to take along each 1d slice of `arr`
    axis: int
        index of the axis with dimension M

    returns: ndarray of shape (A..., K..., B...)
        out[a..., k..., b...] = arr[a..., ind[a..., k..., b...], b...]
    """
    arr = np.asarray(arr)
    ind = np.asarray(ind)
    axis = axis % arr.ndim                # accept negative axes
    ind_shape = (1,) * ind.ndim
    ins_ndim = ind.ndim - (arr.ndim - 1)  # inserted dimensions (the K... pack)
    dest_dims = list(range(axis)) + [None] + list(range(axis + ins_ndim, ind.ndim))

    # could also call np.ix_ here with some dummy arguments, then throw those results away
    inds = []
    for dim, n in zip(dest_dims, arr.shape):
        if dim is None:
            # the M axis is indexed by `ind` itself
            inds.append(ind)
        else:
            # every other axis gets an arange that varies only along `dim` of `ind`
            ind_shape_dim = ind_shape[:dim] + (-1,) + ind_shape[dim + 1:]
            inds.append(np.arange(n).reshape(ind_shape_dim))
    return arr[tuple(inds)]
```
This works as intended:

```python
np.take_along_axis(a, a.argsort(axis=axis), axis=axis) == np.sort(a, axis=axis)
np.take_along_axis(a, a.argmin(axis=axis), axis=axis) == a.min(axis=axis)
np.take_along_axis(a, a.argmax(axis=axis), axis=axis) == a.max(axis=axis)
```
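The comment inside the function mentions `np.ix_` as another way to build the broadcasting `arange`s. Below is a minimal sketch of that variant, followed by a check of the three identities on a random array (using `np.sort`, which returns a sorted copy). The name `take_along_axis_ix` is hypothetical, and the sketch assumes the non-`axis` dimensions of `ind` match those of `arr`:

```python
import numpy as np

def take_along_axis_ix(arr, ind, axis):
    """Hypothetical np.ix_-based variant of the proposed take_along_axis.

    arr has shape (A..., M, B...), ind has shape (A..., K..., B...);
    assumes the A... and B... sizes of ind equal those of arr.
    """
    arr = np.asarray(arr)
    ind = np.asarray(ind)
    axis = axis % arr.ndim                # accept negative axes
    ins_ndim = ind.ndim - (arr.ndim - 1)  # number of K... dimensions
    # Open mesh over ind's shape: grid[i] is an arange that varies only
    # along dimension i of ind (shape (1, ..., n_i, ..., 1)).
    grid = np.ix_(*[np.arange(n) for n in ind.shape])
    # Keep the grids for A... and B..., use ind itself for the M axis,
    # and throw away the grids that belong to the K... dimensions.
    index = grid[:axis] + (ind,) + grid[axis + ins_ndim:]
    return arr[index]

rng = np.random.default_rng(0)
a = rng.integers(0, 100, size=(3, 4, 5))
for axis in range(a.ndim):
    assert np.array_equal(take_along_axis_ix(a, a.argsort(axis=axis), axis),
                          np.sort(a, axis=axis))
    assert np.array_equal(take_along_axis_ix(a, a.argmin(axis=axis), axis),
                          a.min(axis=axis))
    assert np.array_equal(take_along_axis_ix(a, a.argmax(axis=axis), axis),
                          a.max(axis=axis))
```

The `argmin`/`argmax` cases exercise an empty `K...` pack (the reduced axis disappears), while `argsort` has a single `K` dimension of size `M`.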
As far as I can tell, `np.take` isn't suitable for this right now, as it computes:

```
out[a..., k..., b...] = arr[a..., inds[k...], b...]
```

instead of

```
out[a..., k..., b...] = arr[a..., inds[a..., k..., b...], b...]
```
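A small 2-D example of the difference (values arbitrary): `np.take` applies the whole index array to every row and so adds a dimension, whereas the desired result pairs row `i` only with `idx[i]`. In this particular case the wanted values happen to lie on the diagonal of the `np.take` result:

```python
import numpy as np

a = np.array([[3, 9, 1],
              [7, 2, 8]])
idx = np.array([1, 2])           # one index per row

# np.take: out[i, k] = a[i, idx[k]], so every row uses every index
taken = np.take(a, idx, axis=1)  # [[9, 1], [2, 8]]

# Wanted: out[i] = a[i, idx[i]], so each row uses its own index
wanted = a[np.arange(a.shape[0]), idx]  # [9, 8]

# Here len(idx) equals the number of rows, so `taken` is square and the
# wanted values are its diagonal (computed wastefully via np.take)
assert np.array_equal(np.diagonal(taken), wanted)
```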
Perhaps `np.take(..., broadcast=True)` should be a thing that means the above?
If not, there's perhaps enough similarity with `apply_along_axis` to justify the similar name:

```
apply_along_axis: out[a..., k..., b...] = f(arr[a..., :, b...])[k...]
take_along_axis:  out[a..., k..., b...] = arr[a..., inds[a..., k..., b...], b...]
```
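To make the parallel concrete, here is a small sketch (values arbitrary) that sorts each row two ways: with `apply_along_axis`, where `f` sees one 1-d slice at a time, and with precomputed indices gathered in a single vectorized indexing step, which is what `take_along_axis` would wrap. The row index array is built by hand, since the proposed function isn't assumed to exist:

```python
import numpy as np

a = np.array([[3, 9, 1],
              [7, 2, 8]])

# apply_along_axis: out[i, k] = f(a[i, :])[k], with f called once per row
via_apply = np.apply_along_axis(lambda row: row[np.argsort(row)], 1, a)

# take_along_axis-style: out[i, k] = a[i, ind[i, k]], one fancy-indexing call
ind = a.argsort(axis=1)                 # shape (2, 3)
rows = np.arange(a.shape[0])[:, None]   # shape (2, 1), broadcasts against ind
via_take = a[rows, ind]

assert np.array_equal(via_apply, via_take)  # both [[1, 3, 9], [2, 7, 8]]
```

The indices come from one vectorized `argsort` call rather than a Python-level call per slice, which is the main practical advantage of the `take_along_axis` form.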