ENH: Add a function that performs the indexing needed to map argsort to sort · Issue #8708 · numpy/numpy · GitHub
ENH: Add a function that performs the indexing needed to map argsort to sort #8708
Closed
@eric-wieser

Description

There've been two PRs recently (#8703, #8678) that needed this function, and in general it's needed for any vectorized function that returns indices along an axis (argsort, argpartition, argmin, argmax, count, ...).

A python implementation looks something like this:

import numpy as np

def take_along_axis(arr, ind, axis):
    """
    ... here means a "pack" of dimensions, possibly empty

    arr: array_like of shape (A..., M, B...)
        source array
    ind: array_like of shape (A..., K..., B...)
        indices to take along each 1d slice of `arr`
    axis: int
        index of the axis with dimension M

    out: array_like of shape (A..., K..., B...)
        out[a..., k..., b...] = arr[a..., ind[a..., k..., b...], b...]
    """
    # accept any array_like, as the docstring promises
    arr = np.asarray(arr)
    ind = np.asarray(ind)
    if axis < 0:
        axis += arr.ndim   # normalize a negative axis

    ind_shape = (1,) * ind.ndim
    ins_ndim = ind.ndim - (arr.ndim - 1)   # number of inserted dimensions (K...)

    # for each axis of arr, the axis of ind it lines up with; None marks `axis` itself
    dest_dims = list(range(axis)) + [None] + list(range(axis+ins_ndim, ind.ndim))

    # could also call np.ix_ here with some dummy arguments, then throw those results away
    inds = []
    for dim, n in zip(dest_dims, arr.shape):
        if dim is None:
            # the indexed axis uses the caller's indices directly
            inds.append(ind)
        else:
            # every other axis gets an arange that broadcasts against ind
            ind_shape_dim = ind_shape[:dim] + (-1,) + ind_shape[dim+1:]
            inds.append(np.arange(n).reshape(ind_shape_dim))

    return arr[tuple(inds)]

Which works as intended:

take_along_axis(a, a.argsort(axis=axis), axis=axis) == np.sort(a, axis=axis)
take_along_axis(a, a.argmin(axis=axis), axis=axis) == a.min(axis=axis)
take_along_axis(a, a.argmax(axis=axis), axis=axis) == a.max(axis=axis)
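
For readers on NumPy 1.15 or later, a function named np.take_along_axis does exist, so the identities can be checked directly. This is only a sketch with a made-up sample array; note that the shipped function, unlike the implementation proposed in this issue, expects the index array to have the same number of dimensions as the source, so the argmin/argmax results are expanded back with np.expand_dims first.

```python
import numpy as np

# Assumption: NumPy >= 1.15, where np.take_along_axis is available.
a = np.array([[3, 1, 2],
              [9, 7, 8]])   # illustrative data, not from the issue
axis = 1

# argsort already has the same ndim as `a`
sorted_via_take = np.take_along_axis(a, np.argsort(a, axis=axis), axis=axis)

# argmin/argmax drop the axis, so re-insert it before indexing
min_idx = np.expand_dims(np.argmin(a, axis=axis), axis=axis)
max_idx = np.expand_dims(np.argmax(a, axis=axis), axis=axis)
min_via_take = np.take_along_axis(a, min_idx, axis=axis)
max_via_take = np.take_along_axis(a, max_idx, axis=axis)

print(np.array_equal(sorted_via_take, np.sort(a, axis=axis)))
print(np.array_equal(min_via_take, a.min(axis=axis, keepdims=True)))
print(np.array_equal(max_via_take, a.max(axis=axis, keepdims=True)))
```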

As far as I can tell, np.take isn't suitable for this right now, because it computes:

out[a..., k..., b...] = arr[a..., inds[k...], b...]

instead of

out[a..., k..., b...] = arr[a..., inds[a..., k..., b...], b...]
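
A small sketch of that difference, using an illustrative 2x3 array that is not from the issue: passing a per-row index array to np.take applies every row of indices to every row of the source, so the result gains extra dimensions instead of pairing rows up.

```python
import numpy as np

arr = np.array([[30, 10, 20],
                [60, 50, 40]])   # illustrative data
ind = np.argsort(arr, axis=1)    # shape (2, 3): one row of indices per row of arr

# np.take does not pair ind[a] with arr[a]; it uses all of ind for every row:
# taken[a, k0, k1] == arr[a, ind[k0, k1]]
taken = np.take(arr, ind, axis=1)
print(taken.shape)               # (2, 2, 3) rather than (2, 3)

# The row-wise result wanted here is only the "diagonal" of that output
rows = np.arange(arr.shape[0])
wanted = taken[rows, rows]
print(np.array_equal(wanted, np.sort(arr, axis=1)))
```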

Perhaps np.take(..., broadcast=True) should be a thing that means the above?

If not, there's perhaps enough similarity with apply_along_axis to justify the similar name:

  • apply_along_axis: out[a..., k..., b...] = f(arr[a..., :, b...])
  • take_along_axis: out[a..., k..., b...] = arr[a..., inds[a..., k..., b...], b...]
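
To make the parallel concrete, here is a rough explicit-loop reading of the take_along_axis formula, compared against np.apply_along_axis. The helper name take_along_axis_loop is hypothetical, and the sketch assumes a non-negative axis and an index array with the same number of dimensions as the source (a single k dimension).

```python
import numpy as np

def take_along_axis_loop(arr, inds, axis):
    """Hypothetical loop form of out[a..., k, b...] = arr[a..., inds[a..., k, b...], b...].

    Assumes axis >= 0 and inds.ndim == arr.ndim.
    """
    out = np.empty(inds.shape, dtype=arr.dtype)
    for pos in np.ndindex(*inds.shape):
        # pos is (a..., k, b...); swap k for the index stored at that position
        src = pos[:axis] + (inds[pos],) + pos[axis + 1:]
        out[pos] = arr[src]
    return out

arr = np.array([[30, 10, 20],
                [60, 50, 40]])   # illustrative data
axis = 1

# take_along_axis style: look up precomputed indices per 1d slice
via_take = take_along_axis_loop(arr, np.argsort(arr, axis=axis), axis)

# apply_along_axis style: run a function on each 1d slice
via_apply = np.apply_along_axis(np.sort, axis, arr)

print(np.array_equal(via_take, via_apply))
```

Both calls walk the same 1d slices arr[a..., :, b...]; one applies a function to each slice, the other indexes each slice with its own row of indices.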
