-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
ENH: Implement take_along_axis as described in #8708 #8714
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
This is the reduced version that does not allow any insertion of extra dimensions
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,10 +16,222 @@ | |
__all__ = [ | ||
'column_stack', 'row_stack', 'dstack', 'array_split', 'split', | ||
'hsplit', 'vsplit', 'dsplit', 'apply_over_axes', 'expand_dims', | ||
'apply_along_axis', 'kron', 'tile', 'get_array_wrap' | ||
'apply_along_axis', 'kron', 'tile', 'get_array_wrap', 'take_along_axis', | ||
'put_along_axis' | ||
] | ||
|
||
|
||
def _make_along_axis_idx(arr, indices, axis): | ||
# compute dimensions to iterate over | ||
if arr.ndim != indices.ndim: | ||
raise ValueError( | ||
"`indices` and `arr` must have the same number of dimensions") | ||
shape_ones = (1,) * indices.ndim | ||
dest_dims = list(range(axis)) + [None] + list(range(axis+1, indices.ndim)) | ||
|
||
# build a fancy index, consisting of orthogonal aranges, with the | ||
# requested index inserted at the right location | ||
fancy_index = [] | ||
for dim, n in zip(dest_dims, arr.shape): | ||
if dim is None: | ||
fancy_index.append(indices) | ||
else: | ||
ind_shape = shape_ones[:dim] + (-1,) + shape_ones[dim+1:] | ||
fancy_index.append(_nx.arange(n).reshape(ind_shape)) | ||
|
||
return tuple(fancy_index) | ||
|
||
|
||
def take_along_axis(arr, indices, axis): | ||
""" | ||
Take the elements described by `indices` along each 1-D slice of the given | ||
`axis`, matching up subspaces of arr and indices. | ||
|
||
This function can be used to index with the result of `argsort`, `argmax`, | ||
and other `arg` functions. | ||
|
||
This is equivalent to (but faster than) the following use of `ndindex` and | ||
`s_`, which sets each of ``ii`` and ``kk`` to a tuple of indices:: | ||
|
||
Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
J = indices.shape[axis] | ||
|
||
for ii in ndindex(Ni): | ||
for j in range(J): | ||
for kk in ndindex(Nk): | ||
a_1d = a[ii + s_[j,] + kk] | ||
out[ii + s_[j,] + kk] = a_1d[indices[ii + s_[j,] + kk]] | ||
|
||
Equivalently, eliminating the inner loop, this can be expressed as:: | ||
|
||
Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
J = indices.shape[axis] | ||
|
||
for ii in ndindex(Ni): | ||
for kk in ndindex(Nk): | ||
a_1d = a[ii + s_[:,] + kk] | ||
out[ii + s_[:,] + kk] = a_1d[indices[ii + s_[:,] + kk]] | ||
|
||
.. versionadded:: 1.15.0 | ||
|
||
Parameters | ||
---------- | ||
arr: array_like (Ni..., M, Nk...) | ||
source array | ||
indices: array_like (Ni..., J, Nk...) | ||
indices to take along each 1d slice of `arr` | ||
axis: int | ||
the axis to take 1d slices along | ||
|
||
Returns | ||
------- | ||
out: ndarray (Ni..., J, Nk...) | ||
The indexed result, as described above. | ||
|
||
See Also | ||
-------- | ||
take : Take along an axis without matching up subspaces | ||
|
||
Examples | ||
-------- | ||
|
||
For this sample array | ||
|
||
>>> a = np.array([[10, 30, 20], [60, 40, 50]]) | ||
|
||
We can sort either by using sort directly, or argsort and this function | ||
|
||
>>> np.sort(a, axis=1) | ||
array([[10, 20, 30], | ||
[40, 50, 60]]) | ||
>>> ai = np.argsort(a, axis=1); ai | ||
array([[0, 2, 1], | ||
[1, 2, 0]], dtype=int64) | ||
>>> np.take_along_axis(a, ai, axis=1) | ||
array([[10, 20, 30], | ||
[40, 50, 60]]) | ||
|
||
The same works for max and min, if you expand the dimensions: | ||
|
||
>>> np.expand_dims(np.max(a, axis=1), axis=1) | ||
array([[30], | ||
[60]]) | ||
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) | ||
>>> ai | ||
array([[1], | ||
[0], dtype=int64) | ||
>>> np.take_along_axis(a, ai, axis=1) | ||
array([[30], | ||
[60]]) | ||
|
||
If we want to get the max and min at the same time, we can stack the | ||
indices first | ||
|
||
>>> ai_min = np.expand_dims(np.argmin(a, axis=1), axis=1) | ||
>>> ai_max = np.expand_dims(np.argmax(a, axis=1), axis=1) | ||
>>> ai = np.concatenate([ai_min, ai_max], axis=axis) | ||
>> ai | ||
array([[0, 1], | ||
[1, 0]], dtype=int64) | ||
>>> np.take_along_axis(a, ai, axis=1) | ||
array([[10, 30], | ||
[40, 60]]) | ||
""" | ||
# normalize inputs | ||
arr = asanyarray(arr) | ||
indices = asanyarray(indices) | ||
if axis is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add support for arr.ndim > indices.ndim?
Use case / example:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think your use-case is well-motivated. A more explicit way to achieve that would be: y = x.sum(axis=0, keepdims=True)
take_along_axis(x, np.argsort(y, axis=1), axis=1) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole point is that with the one-liner addition f(x) can add or remove axes at will (some manual broadcasting required if it replaces or transposes axes, which however happens automatically if you e.g. wrap this in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note that in the comments above, we decide that perhaps it's best to not allow any case other than
|
||
arr = arr.ravel() | ||
axis = 0 | ||
else: | ||
axis = normalize_axis_index(axis, arr.ndim) | ||
if not _nx.issubdtype(indices.dtype, _nx.integer): | ||
raise IndexError('arrays used as indices must be of integer type') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just copying There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Bools are not considered ints here right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correct, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, should have tried with arrays; not very logical but best to stick with numpy practice here. |
||
|
||
# use the fancy index | ||
return arr[_make_along_axis_idx(arr, indices, axis)] | ||
|
||
|
||
def put_along_axis(arr, indices, values, axis): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. default to axis=-1 like 99% of other numpy functions? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wish that were true. |
||
""" | ||
Put `values` at the elements described by `indices` along each 1-D slice of | ||
the given `axis` of `arr`, matching up subspaces of arr and indices. | ||
|
||
This function can be used to index with the result of `argsort`, `argmax`, | ||
and other `arg` functions. | ||
|
||
This is equivalent to (but faster than) the following use of `ndindex` and | ||
`s_`, which sets each of ``ii`` and ``kk`` to a tuple of indices:: | ||
|
||
# extract the subshapes as labelled in the docs below | ||
Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
J = indices.shape[axis] | ||
|
||
for ii in ndindex(Ni): | ||
for j in range(J): | ||
for kk in ndindex(Nk): | ||
a_1d = a[ii + s_[j,] + kk] | ||
a_1d[indices[ii + s_[j,] + kk]] = values[ii + s_[j,] + kk] | ||
|
||
Equivalently, eliminating the inner loop, this can be expressed as:: | ||
|
||
Ni, Nk = a.shape[:axis], a.shape[axis+1:] | ||
for ii in ndindex(Ni): | ||
for kk in ndindex(Nk): | ||
a_1d = a[ii + s_[:,] + kk] | ||
a_1d[indices[ii + s_[:,] + kk]] = values[ii + s_[:,] + kk] | ||
|
||
.. versionadded:: 1.15.0 | ||
|
||
Parameters | ||
---------- | ||
arr: array_like (Ni..., M, Nk...) | ||
source array | ||
indices: array_like (Ni..., J, Nk...) | ||
indices to change along each 1d slice of `arr` | ||
values: array_like (Ni..., J, Nk...) | ||
values to insert at those indices. Its shape and dimension are | ||
broadcast to match that of `indices`. | ||
axis: int | ||
the axis to take 1d slices along | ||
|
||
See Also | ||
-------- | ||
take_along_axis : Take along an axis without matching up subspaces | ||
|
||
Examples | ||
-------- | ||
|
||
For this sample array | ||
|
||
>>> a = np.array([[10, 30, 20], [60, 40, 50]]) | ||
|
||
We can replace the maximum values with: | ||
|
||
>>> ai = np.expand_dims(np.argmax(a, axis=1), axis=1) | ||
>>> ai | ||
array([[1], | ||
[0]], dtype=int64) | ||
>>> np.put_along_axis(a, ai, 99, axis=1) | ||
>>> a | ||
array([[10, 99, 20], | ||
[99, 40, 50]]) | ||
|
||
""" | ||
# normalize inputs | ||
indices = asanyarray(indices) | ||
if axis is None: | ||
arr = arr.ravel() | ||
axis = 0 | ||
else: | ||
axis = normalize_axis_index(axis, arr.ndim) | ||
if not _nx.issubdtype(indices.dtype, _nx.integer): | ||
raise IndexError('arrays used as indices must be of integer type') | ||
|
||
# use the fancy index | ||
arr[_make_along_axis_idx(arr, indices, axis)] = values | ||
|
||
|
||
def apply_along_axis(func1d, axis, arr, *args, **kwargs): | ||
""" | ||
Apply a function to 1-D slices along the given axis. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
default to axis=-1 like 99% of other numpy functions?