8000 ENH: Allow rolling multiple axes at the same time. · numpy/numpy@612caca · GitHub
[go: up one dir, main page]

Skip to content

Commit 612caca

Browse files
committed
ENH: Allow rolling multiple axes at the same time.
A quick test suggests that this implementation from @seberg, relying on slices rather than index arrays, is 1.5~3x faster than the previous (1D) roll (depending on the axis). Also switched the error message for invalid inputs to match the one of ufuncs, because the axis can actually also be negative.
1 parent 2af06c8 commit 612caca

File tree

3 files changed

+86
-22
lines changed

3 files changed

+86
-22
lines changed

doc/release/1.12.0-notes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,10 @@ Generalized Ufuncs will now unlock the GIL
159159
Generalized Ufuncs, including most of the linalg module, will now unlock
160160
the Python global interpreter lock.
161161

162+
np.roll can now roll multiple axes at the same time
163+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
164+
The ``shift`` and ``axis`` arguments to ``roll`` are now broadcast against each
165+
other, and each specified axis is shifted accordingly.
162166

163167
The *__complex__* method has been implemented on the ndarray object
164168
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

numpy/core/numeric.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import division, absolute_import, print_function
22

3-
import sys
3+
import collections
4+
import itertools
45
import operator
6+
import sys
57
import warnings
6-
import collections
8+
79
from numpy.core import multiarray
810
from . import umath
911
from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE,
@@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None):
13401342
----------
13411343
a : array_like
13421344
Input array.
1343-
shift : int
1344-
The number of places by which elements are shifted.
1345-
axis : int, optional
1346-
The axis along which elements are shifted. By default, the array
1347-
is flattened before shifting, after which the original
1345+
shift : int or tuple of ints
1346+
The number of places by which elements are shifted. If a tuple,
1347+
then `axis` must be a tuple of the same size, and each of the
1348+
given axes is shifted by the corresponding number. If an int
1349+
while `axis` is a tuple of ints, then the same value is used for
1350+
all given axes.
1351+
axis : int or tuple of ints, optional
1352+
Axis or axes along which elements are shifted. By default, the
1353+
array is flattened before shifting, after which the original
13481354
shape is restored.
13491355
13501356
Returns
@@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None):
13571363
rollaxis : Roll the specified axis backwards, until it lies in a
13581364
given position.
13591365
1366+
Notes
1367+
-----
1368+
.. versionadded:: 1.12.0
1369+
1370+
Supports rolling over multiple dimensions simultaneously.
1371+
13601372
Examples
13611373
--------
13621374
>>> x = np.arange(10)
@@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None):
13801392
"""
13811393
a = asanyarray(a)
13821394
if axis is None:
1383-
n = a.size
1384-
reshape = True
1395+
return roll(a.ravel(), shift, 0).reshape(a.shape)
1396+
13851397
else:
1386-
try:
1387-
n = a.shape[axis]
1388-
except IndexError:
1389-
raise ValueError('axis must be >= 0 and < %d' % a.ndim)
1390-
reshape = False
1391-
if n == 0:
1392-
return a
1393-
shift %= n
1394-
indexes = concatenate((arange(n - shift, n), arange(n - shift)))
1395-
res = a.take(indexes, axis)
1396-
if reshape:
1397-
res = res.reshape(a.shape)
1398-
return res
1398+
broadcasted = broadcast(shift, axis)
1399+
if len(broadcasted.shape) > 1:
1400+
raise ValueError(
1401+
"'shift' and 'axis' should be scalars or 1D sequences")
1402+
shifts = {ax: 0 for ax in range(a.ndim)}
1403+
for sh, ax in broadcasted:
1404+
if -a.ndim <= ax < a.ndim:
1405+
shifts[ax % a.ndim] += sh
1406+
else:
1407+
raise ValueError("'axis' entry is out of bounds")
1408+
1409+
rolls = [((slice(None), slice(None)),)] * a.ndim
1410+
for ax, offset in shifts.items():
1411+
offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters.
1412+
if offset:
1413+
# (original, result), (original, result)
1414+
rolls[ax] = ((slice(None, -offset), slice(offset, None)),
1415+
(slice(-offset, None), slice(None, offset)))
1416+
1417+
result = empty_like(a)
1418+
for indices in itertools.product(*rolls):
1419+
arr_index, res_index = zip(*indices)
1420+
result[res_index] = a[arr_index]
1421+
1422+
return result
13991423

14001424

14011425
def rollaxis(a, axis, start=0):

numpy/core/tests/test_numeric.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2145,6 +2145,42 @@ def test_roll2d(self):
21452145
x2r = np.roll(x2, 1, axis=1)
21462146
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
21472147

2148+
# Roll multiple axes at once.
2149+
x2r = np.roll(x2, 1, axis=(0, 1))
2150+
assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))
2151+
2152+
x2r = np.roll(x2, (1, 0), axis=(0, 1))
2153+
assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))
2154+
2155+
x2r = np.roll(x2, (-1, 0), axis=(0, 1))
2156+
assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]]))
2157+
2158+
x2r = np.roll(x2, (0, 1), axis=(0, 1))
2159+
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
2160+
2161+
x2r = np.roll(x2, (0, -1), axis=(0, 1))
2162+
assert_equal(x2r, np.array([[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]))
2163+
2164+
x2r = np.roll(x2, (1, 1), axis=(0, 1))
2165+
assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]]))
2166+
2167+
x2r = np.roll(x2, (-1, -1), axis=(0, 1))
2168+
assert_equal(x2r, np.array([[6, 7, 8, 9, 5], [1, 2, 3, 4, 0]]))
2169+
2170+
# Roll the same axis multiple times.
2171+
x2r = np.roll(x2, 1, axis=(0, 0))
2172+
assert_equal(x2r, np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]))
2173+
2174+
x2r = np.roll(x2, 1, axis=(1, 1))
2175+
assert_equal(x2r, np.array([[3, 4, 0, 1, 2], [8, 9, 5, 6, 7]]))
2176+
2177+
# Roll more than one turn in either direction.
2178+
x2r = np.roll(x2, 6, axis=1)
2179+
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
2180+
2181+
x2r = np.roll(x2, -4, axis=1)
2182+
assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]]))
2183+
21482184
def test_roll_empty(self):
21492185
x = np.array([])
21502186
assert_equal(np.roll(x, 1), np.array([]))

0 commit comments

Comments
 (0)
0