8000 Allow rolling multiple axes at the same time. by anntzer · Pull Request #7438 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

Allow rolling multiple axes at the same time. #7438

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/release/1.12.0-notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,10 @@ Generalized Ufuncs will now unlock the GIL
Generalized Ufuncs, including most of the linalg module, will now unlock
the Python global interpreter lock.

np.roll can now roll multiple axes at the same time
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The ``shift`` and ``axis`` arguments to ``roll`` are now broadcast against each
other, and each specified axis is shifted accordingly.

The *__complex__* method has been implemented on the ndarray object
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
68 changes: 46 additions & 22 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import division, absolute_import, print_function

import sys
import collections
import itertools
import operator
import sys
import warnings
import collections

from numpy.core import multiarray
from . import umath
from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
Expand Down Expand Up @@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None):
----------
a : array_like
Input array.
shift : int
The number of places by which elements are shifted.
axis : int, optional
The axis along which elements are shifted. By default, the array
is flattened before shifting, after which the original
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would be good to add a ..versionadded:: thingy here for the sequence of ints (also slightly prefer sequence over tuple).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it should really be "array-like of ints"? I kept tuple mostly for consistency with e.g. np.sum.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, OK, no I prefer tuple or sequence of ints. Thought we used sequence mostly, but if we use tuple as well, just keep it as it is. It currently is more an "array-like of ints" in implementation, but I don't like that too much to be honest (you could put in a 2x2 array...).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So to be clear. If you don't mind just ignoring the broadcasting for now, I think I would prefer something like:

try:
     axes = tuple(axis)
except:
    axes = (axis,)

or the inverted try: operator.index(axis). But if the broadcasting is important to you, maybe as is, is just simpler, hmmm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, I did not know that iterating over a broadcast object would result in nd-iteration (i.e. a 2x2 axis would get flattened out at that point).
I still like the broadcasting behavior, though.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could just check broadcast(...).shape to be 0 or 1D and raise an error otherwise, that would seem acceptable to me. Or multiply the length of the tuples manually if they are length 1, but probably that gets annoying.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coming back to the tuple-vs-sequence issue, I just noticed that sum, for example, would error out if a non-tuple sequence is passed as the axis argument. Not sure if that means sum should be improved or we should only accept tuples to avoid weird edge cases...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does not matter. In most of the python side stuff, we will allow any sequence and sometimes iterable (tuple(iterable) works). I don't think we have to strife for perfect consistency when it comes to tuples vs. sequences vs. iterables for the this type of arguments.

If you prefer, you can make it a strict tuple here as well, but basically I would pick whatever is easiest.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems we got side tracked here and forgot about the versionadded tag. It would be good to mention somewhere that roll along multiple axes was added with this tag. IIRC it could either go here, or also into the Notes section.

while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.

Returns
Expand All @@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None):
rollaxis : Roll the specified axis backwards, until it lies in a
given position.

Notes
-----
.. versionadded:: 1.12.0

Supports rolling over multiple dimensions simultaneously.

Examples
--------
>>> x = np.arange(10)
Expand All @@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None):
"""
a = asanyarray(a)
if axis is None:
n = a.size
reshape = True
return roll(a.ravel(), shift, 0).reshape(a.shape)

else:
try:
n = a.shape[axis]
except IndexError:
raise ValueError('axis must be >= 0 and < %d' % a.ndim)
reshape = False
if n == 0:
return a
shift %= n
indexes = concatenate((arange(n - shift, n), arange(n - shift)))
res = a.take(indexes, axis)
if reshape:
res = res.reshape(a.shape)
return res
broadcasted = broadcast(shift, axis)
if len(broadcasted.shape) > 1:
raise ValueError(
"'shift' and 'axis' should be scalars or 1D sequences")
shifts = {ax: 0 for ax in range(a.ndim)}
for sh, ax in broadcasted:
if -a.ndim <= ax < a.ndim:
shifts[ax % a.ndim] += sh
else:
raise ValueError("'axis' entry is out of bounds")

rolls = [((slice(None), slice(None)),)] * a.ndim
for ax, offset in shifts.items():
offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters.
if offset:
# (original, result), (original, result)
rolls[ax] = ((slice(None, -offset), slice(offset, None)),
(slice(-offset, None), slice(None, offset)))

result = empty_like(a)
for indices in itertools.product(*rolls):
arr_index, res_index = zip(*indices)
result[res_index] = a[arr_index]

return result


def rollaxis(a, axis, start=0):
Expand Down
36 changes: 36 additions & 0 deletions numpy/core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,6 +2145,42 @@ def test_roll2d(self):
x2r = np.roll(x2, 1, axis=1)
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))

# Roll multiple axes at once.
x2r = np.roll(x2, 1, axis=(0, 1))
assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))

x2r = np.roll(x2, (1, 0), axis=(0, 1))
assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))

x2r = np.roll(x2, (-1, 0), axis=(0, 1))
assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))

x2r = np.roll(x2, (0, 1), axis=(0, 1))
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))

x2r = np.roll(x2, (0, -1), axis=(0, 1))
assert_equal(x2r, np.array([[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]))

x2r = np.roll(x2, (1, 1), axis=(0, 1))
assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))

x2r = np.roll(x2, (-1, -1), axis=(0, 1))
assert_equal(x2r, np.array([[6, 7, 8, 9, 5], [1, 2, 3, 4, 0]]))

# Roll the same axis multiple times.
x2r = np.roll(x2, 1, axis=(0, 0))
assert_equal(x2r, np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]))

x2r = np.roll(x2, 1, axis=(1, 1))
assert_equal(x2r, np.array([[3, 4, 0, 1, 2], [8, 9, 5, 6, 7]]))

# Roll more than one turn in either direction.
x2r = np.roll(x2, 6, axis=1)
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))

x2r = np.roll(x2, -4, axis=1)
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))

def test_roll_empty(self):
x = np.array([])
assert_equal(np.roll(x, 1), np.array([]))
Expand Down
0