ENH: Implement take_along_axis as described in #8708 by eric-wieser · Pull Request #8714 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: Implement take_along_axis as described in #8708 #8714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
ENH: Add (put|take)_along_axis as described in #8708
This is the reduced version that does not allow any insertion of extra dimensions
  • Loading branch information
eric-wieser committed May 16, 2018
commit 7a87604f08160e4ecbff0b22d9fecb8bb6b43fcf
17 changes: 17 additions & 0 deletions doc/release/1.15.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,5 +326,22 @@ Increased performance in ``random.permutation`` for multidimensional arrays
array dimensions. Previously the fast path was only used for 1-d arrays.


New ``np.take_along_axis`` and ``np.put_along_axis`` functions
--------------------------------------------------------------
When used on multidimensional arrays, ``argsort``, ``argmin``, ``argmax``, and
``argpartition`` return arrays that are difficult to use as indices.
``take_along_axis`` provides an easy way to use these indices to look up values
within an array, so that::

np.take_along_axis(a, np.argsort(a, axis=axis), axis=axis)

is the same as::

np.sort(a, axis=axis)

``np.put_along_axis`` acts as the dual operation for writing to these indices
within an array.


Changes
=======
2 changes: 2 additions & 0 deletions doc/source/reference/routines.indexing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Indexing-like operations
:toctree: generated/

take
take_along_axis
choose
compress
diag
Expand All @@ -50,6 +51,7 @@ Inserting data into arrays

place
put
put_along_axis
putmask
fill_diagonal

Expand Down
8 changes: 7 additions & 1 deletion numpy/core/fromnumeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ def take(a, indices, axis=None, out=None, mode='raise'):
--------
compress : Take elements using a boolean mask
ndarray.take : equivalent method
take_along_axis : Take elements by matching the array and the index arrays

Notes
-----
Expand Down Expand Up @@ -478,6 +479,7 @@ def put(a, ind, v, mode='raise'):
See Also
--------
putmask, place
put_along_axis : Put elements by matching the array and the index arrays

Examples
--------
Expand Down Expand Up @@ -723,7 +725,9 @@ def argpartition(a, kth, axis=-1, kind='introselect', order=None):
-------
index_array : ndarray, int
Array of indices that partition `a` along the specified axis.
In other words, ``a[index_array]`` yields a partitioned `a`.
If `a` is one-dimensional, ``a[index_array]`` yields a partitioned `a`.
More generally, ``np.take_along_axis(a, index_array, axis=axis)`` always
yields the partitioned `a`, irrespective of dimensionality.

See Also
--------
Expand Down Expand Up @@ -904,6 +908,8 @@ def argsort(a, axis=-1, kind='quicksort', order=None):
index_array : ndarray, int
Array of indices that sort `a` along the specified axis.
If `a` is one-dimensional, ``a[index_array]`` yields a sorted `a`.
More generally, ``np.take_along_axis(a, index_array, axis=axis)`` always
yields the sorted `a`, irrespective of dimensionality.

See Also
--------
Expand Down
214 changes: 213 additions & 1 deletion numpy/lib/shape_base.py
9E12
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,222 @@
__all__ = [
'column_stack', 'row_stack', 'dstack', 'array_split', 'split',
'hsplit', 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims',
'apply_along_axis', 'kron', 'tile', 'get_array_wrap'
'apply_along_axis', 'kron', 'tile', 'get_array_wrap', 'take_along_axis',
'put_along_axis'
]


def _make_along_axis_idx(arr, indices, axis):
# compute dimensions to iterate over
if arr.ndim != indices.ndim:
raise ValueError(
"`indices` and `arr` must have the same number of dimensions")
shape_ones = (1,) * indices.ndim
dest_dims = list(range(axis)) + [None] + list(range(axis+1, indices.ndim))

# build a fancy index, consisting of orthogonal aranges, with the
# requested index inserted at the right location
fancy_index = []
for dim, n in zip(dest_dims, arr.shape):
if dim is None:
fancy_index.append(indices)
else:
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:]
fancy_index.append(_nx.arange(n).reshape(ind_shape))

return tuple(fancy_index)


def take_along_axis(arr, indices, axis):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

"""
Take the elements described by `indices` along each 1-D slice of the given
`axis`, matching up subspaces of arr and indices.

This function can be used to index with the result of `argsort`, `argmax`,
and other `arg` functions.

This is equivalent to (but faster than) the following use of `ndindex` and
`s_`, which sets each of ``ii`` and ``kk`` to a tuple of indices::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
J = indices.shape[axis]

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d = a[ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        out_1d = out[ii + s_[:,] + kk]
        for j in range(J):
            out_1d[j] = a_1d[indices_1d[j]]

Equivalently, eliminating the inner loop, this can be expressed as::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
J = indices.shape[axis]

for ii in ndindex(Ni):
for kk in ndindex(Nk):
a_1d = a[ii + s_[:,] + kk]
out[ii + s_[:,] + kk] = a_1d[indices[ii + s_[:,] + kk]]

.. versionadded:: 1.15.0

Parameters
----------
arr: array_like (Ni..., M, Nk...)
source array
indices: array_like (Ni..., J, Nk...)
indices to take along each 1d slice of `arr`
axis: int
the axis to take 1d slices along

Returns
-------
out: ndarray (Ni..., J, Nk...)
The indexed result, as described above.

See Also
--------
take : Take along an axis without matching up subspaces

Examples
--------

For this sample array

>>> a = np.array([[10, 30, 20], [60, 40, 50]])

We can sort either by using sort directly, or argsort and this function

>>> np.sort(a, axis=1)
array([[10, 20, 30],
[40, 50, 60]])
>>> ai = np.argsort(a, axis=1); ai
array([[0, 2, 1],
[1, 2, 0]], dtype=int64)
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 20, 30],
[40, 50, 60]])

The same works for max and min, if you expand the dimensions:

>>> np.expand_dims(np.max(a, axis=1), axis=1)
array([[30],
[60]])
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
[0]], dtype=int64)
>>> np.take_along_axis(a, ai, axis=1)
array([[30],
[60]])

If we want to get the max and min at the same time, we can stack the
indices first

>>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1)
>>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai = np.concatenate([ai_min, ai_max], axis=1)
>>> ai
array([[0, 1],
[1, 0]], dtype=int64)
>>> np.take_along_axis(a, ai, axis=1)
array([[10, 30],
[40, 60]])
"""
# normalize inputs
arr = asanyarray(arr)
indices = asanyarray(indices)
if axis is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add support for arr.ndim > indices.ndim?
It's a single line here (+ unit test + docs):

indices = indices.reshape((1, ) * (arr.ndim - indices.ndim) + indices.shape)

Use case / example:

x = np.array([[5, 3, 2, 8, 1],
              [0, 7, 1, 3, 2]])
# Completely arbitrary y = f(x0, x1, ..., xn), embarrassingly parallel along axis=-1
# Here we only have x0, but we could have more.
y = x.sum(axis=0)
# Sort the x's, moving the ones that cause the smallest y's to the left
take_along_axis(x, np.argsort(y))

Copy link
Member Author
@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think your use-case is well-motivated. A more explicit way to achieve that would be:

y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1)

Copy link
Contributor
@crusaderky crusaderky May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in xarray.apply_ufunc).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note that in the comments above, we decide that perhaps it's best to not allow any case other than indices.ndim == arr.ndim, since there's no obvious right choice.

take_along_axis only really makes sense if you endeavor to keep all your axes aligned. xarray can probably solve that by axis names alone, but in numpy you need to indicate that by axis position. Therefore, you can't afford to let your axes collapse, and have numpy guess which one you lost: in your case, you're advocating for it to guess the left-most one should be reinserted - but this is only the case because you did sum(axis=0).

arr = arr.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, arr.ndim)
if not _nx.issubdtype(indices.dtype, _nx.integer):
raise IndexError('arrays used as indices must be of integer type')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be a TypeError, I think (at least, [1, 2][1.] raises TypeError).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just copying np.array([1, 2])[np.array(1.)] here, which gives IndexError. This is just that error message, but without the bit about booleans.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bools are not considered ints here right?

Copy link
Member Author
@eric-wieser eric-wieser Feb 28, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, np.issubdtype(np.bool_, np.integer) is false. There's a test for this error in this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, should have tried with arrays; not very logical but best to stick with numpy practice here.


# use the fancy index
return arr[_make_along_axis_idx(arr, indices, axis)]


def put_along_axis(arr, indices, values, axis):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

default to axis=-1 like 99% of other numpy functions?

Copy link
Member Author
@eric-wieser eric-wieser May 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wish that were true. concatenate defaults to axis=0, and other functions default to axis=None. Since the axis is a key part of the function, it seems best just to require it.

"""
Put `values` at the elements described by `indices` along each 1-D slice of
the given `axis` of `arr`, matching up subspaces of arr and indices.

This function can be used to index with the result of `argsort`, `argmax`,
and other `arg` functions.

This is equivalent to (but faster than) the following use of `ndindex` and
`s_`, which sets each of ``ii`` and ``kk`` to a tuple of indices::

# extract the subshapes as labelled in the docs below
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
J = indices.shape[axis]

for ii in ndindex(Ni):
    for kk in ndindex(Nk):
        a_1d = a[ii + s_[:,] + kk]
        indices_1d = indices[ii + s_[:,] + kk]
        values_1d = values[ii + s_[:,] + kk]
        for j in range(J):
            a_1d[indices_1d[j]] = values_1d[j]

Equivalently, eliminating the inner loop, this can be expressed as::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
for ii in ndindex(Ni):
for kk in ndindex(Nk):
a_1d = a[ii + s_[:,] + kk]
a_1d[indices[ii + s_[:,] + kk]] = values[ii + s_[:,] + kk]

.. versionadded:: 1.15.0

Parameters
----------
arr: array_like (Ni..., M, Nk...)
destination array, modified in place
indices: array_like (Ni..., J, Nk...)
indices to change along each 1d slice of `arr`
values: array_like (Ni..., J, Nk...)
values to insert at those indices. Its shape and dimension are
broadcast to match that of `indices`.
axis: int
the axis to take 1d slices along

See Also
--------
take_along_axis : Take elements along an axis, matching up subspaces of arr and indices

Examples
--------

For this sample array

>>> a = np.array([[10, 30, 20], [60, 40, 50]])

We can replace the maximum values with:

>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1)
>>> ai
array([[1],
[0]], dtype=int64)
>>> np.put_along_axis(a, ai, 99, axis=1)
>>> a
array([[10, 99, 20],
[99, 40, 50]])

"""
# normalize inputs
indices = asanyarray(indices)
if axis is None:
arr = arr.ravel()
axis = 0
else:
axis = normalize_axis_index(axis, arr.ndim)
if not _nx.issubdtype(indices.dtype, _nx.integer):
raise IndexError('arrays used as indices must be of integer type')

# use the fancy index
arr[_make_along_axis_idx(arr, indices, axis)] = values


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
"""
Apply a function to 1-D slices along the given axis.
Expand Down
64 changes: 63 additions & 1 deletion numpy/lib/tests/test_shape_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,78 @@

import numpy as np
import warnings
import functools

from numpy.lib.shape_base import (
apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
vsplit, dstack, column_stack, kron, tile, expand_dims,
vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis
)
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_raises, assert_warns
)


def _add_keepdims(func):
""" hack in a keepdims argument into a function taking an axis """
@functools.wraps(func)
def wrapped(a, axis, **kwargs):
res = func(a, axis=axis, **kwargs)
return np.expand_dims(res, axis=axis)
return wrapped


class TestTakeAlongAxis(object):
    """ Tests for `take_along_axis` """

    def test_argequivalent(self):
        """ Test it translates from arg<func> to <func> """
        from numpy.random import rand
        a = rand(3, 4, 5)

        # (value function, matching index function, extra keyword arguments)
        funcs = [
            (np.sort, np.argsort, dict()),
            (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),
            (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),
            (np.partition, np.argpartition, dict(kth=2)),
        ]

        for func, argfunc, kwargs in funcs:
            for axis in range(a.ndim):
                a_func = func(a, axis=axis, **kwargs)
                ai_func = argfunc(a, axis=axis, **kwargs)
                assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))

    def test_invalid(self):
        """ Test it errors on bad index ndim, index dtype, or axis """
        a = np.ones((10, 10))
        ai = np.ones((10, 2), dtype=np.intp)

        # sanity check
        take_along_axis(a, ai, axis=1)

        # not enough indices
        assert_raises(ValueError, take_along_axis, a, 1, axis=1)
        # bool arrays not allowed
        assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1)
        # float arrays not allowed
        assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1)
        # invalid axis; the indices are otherwise valid, so that only the
        # axis can be responsible for the failure
        assert_raises(np.AxisError, take_along_axis, a, ai, axis=10)

    def test_empty(self):
        """ Test an empty index array gives a result of matching shape """
        a = np.ones((3, 4, 5))
        ai = np.ones((3, 0, 5), dtype=np.intp)

        actual = take_along_axis(a, ai, axis=1)
        assert_equal(actual.shape, ai.shape)

    def test_broadcast(self):
        """ Test that non-indexing dimensions are broadcast """
        a = np.ones((3, 4, 1))
        ai = np.ones((1, 2, 5), dtype=np.intp)
        actual = take_along_axis(a, ai, axis=1)
        assert_equal(actual.shape, (3, 2, 5))


class TestApplyAlongAxis(object):
def test_simple(self):
a = np.ones((20, 10), 'd')
Expand Down
0