diff --git a/numpy/lib/index_tricks.py b/numpy/lib/index_tricks.py index b07fde27db46..6f2aa1d02ed2 100644 --- a/numpy/lib/index_tricks.py +++ b/numpy/lib/index_tricks.py @@ -18,7 +18,6 @@ import numpy.matrixlib as matrix from function_base import diff from numpy.lib._compiled_base import ravel_multi_index, unravel_index -from numpy.lib.stride_tricks import as_strided makemat = matrix.matrix def ix_(*args): @@ -532,20 +531,37 @@ class ndindex(object): (2, 1, 0) """ - def __init__(self, *shape): - x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape)) - self._it = _nx.nditer(x, flags=['multi_index'], order='C') - def __iter__(self): - return self + def __init__(self, *args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + self.nd = len(args) + self.ind = [0]*self.nd + self.index = 0 + self.maxvals = args + tot = 1 + for k in range(self.nd): + tot *= args[k] + self.total = tot + + def _incrementone(self, axis): + if (axis < 0): # base case + return + if (self.ind[axis] < self.maxvals[axis]-1): + self.ind[axis] += 1 + else: + self.ind[axis] = 0 + self._incrementone(axis-1) def ndincr(self): """ Increment the multi-dimensional index by one. - This method is for backward compatibility only: do not use. + `ndincr` takes care of the "wrapping around" of the axes. + It is called by `ndindex.next` and not normally used directly. + """ - self.next() + self._incrementone(self.nd-1) def next(self): """ @@ -557,8 +573,17 @@ def next(self): Returns a tuple containing the indices of the current iteration. """ - self._it.next() - return self._it.multi_index + if (self.index >= self.total): + raise StopIteration + val = tuple(self.ind) + self.index += 1 + self.ndincr() + return val + + def __iter__(self): + return self + + # You can do all this with slice() plus a few special objects, diff --git a/numpy/lib/tests/test_index_tricks.py b/numpy/lib/tests/test_index_tricks.py index 0ede40d5a337..beda2d1462b1 100644 --- a/numpy/lib/tests/test_index_tricks.py +++ b/numpy/lib/tests/test_index_tricks.py @@ -2,7 +2,7 @@ import numpy as np from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where, ndenumerate, fill_diagonal, diag_indices, - diag_indices_from, s_, index_exp, ndindex ) + diag_indices_from, s_, index_exp ) class TestRavelUnravelIndex(TestCase): def test_basic(self): @@ -237,11 +237,5 @@ def test_diag_indices_from(): assert_array_equal(c, np.arange(4)) -def test_ndindex(): - x = list(np.ndindex(1, 2, 3)) - expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))] - assert_array_equal(x, expected) - - if __name__ == "__main__": run_module_suite()