From 612caca61f82e3e20117ec5917d71fd0f48c42ea Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 19 Mar 2016 18:50:55 -0700 Subject: [PATCH] ENH: Allow rolling multiple axes at the same time. A quick test suggests that this implementation from @seberg, relying on slices rather than index arrays, is 1.5~3x faster than the previous (1D) roll (depending on the axis). Also switched the error message for invalid inputs to match the one of ufuncs, because the axis can actually also be negative. --- doc/release/1.12.0-notes.rst | 4 ++ numpy/core/numeric.py | 68 +++++++++++++++++++++----------- numpy/core/tests/test_numeric.py | 36 +++++++++++++++++ 3 files changed, 86 insertions(+), 22 deletions(-) diff --git a/doc/release/1.12.0-notes.rst b/doc/release/1.12.0-notes.rst index 058bdaac7079..4867594945a6 100644 --- a/doc/release/1.12.0-notes.rst +++ b/doc/release/1.12.0-notes.rst @@ -159,6 +159,10 @@ Generalized Ufuncs will now unlock the GIL Generalized Ufuncs, including most of the linalg module, will now unlock the Python global interpreter lock. +np.roll can now roll multiple axes at the same time +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +The ``shift`` and ``axis`` arguments to ``roll`` are now broadcast against each +other, and each specified axis is shifted accordingly. The *__complex__* method has been implemented on the ndarray object ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7f12068553a3..11a95fa7b876 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -1,9 +1,11 @@ from __future__ import division, absolute_import, print_function -import sys +import collections +import itertools import operator +import sys import warnings -import collections + from numpy.core import multiarray from . import umath from .umath import (invert, sin, UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, @@ -1340,11 +1342,15 @@ def roll(a, shift, axis=None): ---------- a : array_like Input array. - shift : int - The number of places by which elements are shifted. - axis : int, optional - The axis along which elements are shifted. By default, the array - is flattened before shifting, after which the original + shift : int or tuple of ints + The number of places by which elements are shifted. If a tuple, + then `axis` must be a tuple of the same size, and each of the + given axes is shifted by the corresponding number. If an int + while `axis` is a tuple of ints, then the same value is used for + all given axes. + axis : int or tuple of ints, optional + Axis or axes along which elements are shifted. By default, the + array is flattened before shifting, after which the original shape is restored. Returns @@ -1357,6 +1363,12 @@ def roll(a, shift, axis=None): rollaxis : Roll the specified axis backwards, until it lies in a given position. + Notes + ----- + .. versionadded:: 1.12.0 + + Supports rolling over multiple dimensions simultaneously. + Examples -------- >>> x = np.arange(10) @@ -1380,22 +1392,34 @@ def roll(a, shift, axis=None): """ a = asanyarray(a) if axis is None: - n = a.size - reshape = True + return roll(a.ravel(), shift, 0).reshape(a.shape) + else: - try: - n = a.shape[axis] - except IndexError: - raise ValueError('axis must be >= 0 and < %d' % a.ndim) - reshape = False - if n == 0: - return a - shift %= n - indexes = concatenate((arange(n - shift, n), arange(n - shift))) - res = a.take(indexes, axis) - if reshape: - res = res.reshape(a.shape) - return res + broadcasted = broadcast(shift, axis) + if len(broadcasted.shape) > 1: + raise ValueError( + "'shift' and 'axis' should be scalars or 1D sequences") + shifts = {ax: 0 for ax in range(a.ndim)} + for sh, ax in broadcasted: + if -a.ndim <= ax < a.ndim: + shifts[ax % a.ndim] += sh + else: + raise ValueError("'axis' entry is out of bounds") + + rolls = [((slice(None), slice(None)),)] * a.ndim + for ax, offset in shifts.items(): + offset %= a.shape[ax] or 1 # If `a` is empty, nothing matters. + if offset: + # (original, result), (original, result) + rolls[ax] = ((slice(None, -offset), slice(offset, None)), + (slice(-offset, None), slice(None, offset))) + + result = empty_like(a) + for indices in itertools.product(*rolls): + arr_index, res_index = zip(*indices) + result[res_index] = a[arr_index] + + return result def rollaxis(a, axis, start=0): diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py index 0040f3a25428..dd9c83b25a4d 100644 --- a/numpy/core/tests/test_numeric.py +++ b/numpy/core/tests/test_numeric.py @@ -2145,6 +2145,42 @@ def test_roll2d(self): x2r = np.roll(x2, 1, axis=1) assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]])) + # Roll multiple axes at once. + x2r = np.roll(x2, 1, axis=(0, 1)) + assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]])) + + x2r = np.roll(x2, (1, 0), axis=(0, 1)) + assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]])) + + x2r = np.roll(x2, (-1, 0), axis=(0, 1)) + assert_equal(x2r, np.array([[5, 6, 7, 8, 9], [0, 1, 2, 3, 4]])) + + x2r = np.roll(x2, (0, 1), axis=(0, 1)) + assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]])) + + x2r = np.roll(x2, (0, -1), axis=(0, 1)) + assert_equal(x2r, np.array([[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]])) + + x2r = np.roll(x2, (1, 1), axis=(0, 1)) + assert_equal(x2r, np.array([[9, 5, 6, 7, 8], [4, 0, 1, 2, 3]])) + + x2r = np.roll(x2, (-1, -1), axis=(0, 1)) + assert_equal(x2r, np.array([[6, 7, 8, 9, 5], [1, 2, 3, 4, 0]])) + + # Roll the same axis multiple times. + x2r = np.roll(x2, 1, axis=(0, 0)) + assert_equal(x2r, np.array([[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]])) + + x2r = np.roll(x2, 1, axis=(1, 1)) + assert_equal(x2r, np.array([[3, 4, 0, 1, 2], [8, 9, 5, 6, 7]])) + + # Roll more than one turn in either direction. + x2r = np.roll(x2, 6, axis=1) + assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]])) + + x2r = np.roll(x2, -4, axis=1) + assert_equal(x2r, np.array([[4, 0, 1, 2, 3], [9, 5, 6, 7, 8]])) + def test_roll_empty(self): x = np.array([]) assert_equal(np.roll(x, 1), np.array([]))