8000 Merge pull request #463 from certik/1.7.x_fix · numpy/numpy@7ae2fb0 · GitHub
[go: up one dir, main page]

Skip to content 8000

Commit 7ae2fb0

Browse files
committed
Merge pull request #463 from certik/1.7.x_fix
1.7.x fix: revert patches
2 parents 318a531 + 8a54c70 commit 7ae2fb0

File tree

2 files changed

+36
-17
lines changed

2 files changed

+36
-17
lines changed

numpy/lib/index_tricks.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy.matrixlib as matrix
1919
from function_base import diff
2020
from numpy.lib._compiled_base import ravel_multi_index, unravel_index
21-
from numpy.lib.stride_tricks import as_strided
2221
makemat = matrix.matrix
2322

2423
def ix_(*args):
@@ -532,20 +531,37 @@ class ndindex(object):
532531
(2, 1, 0)
533532
534533
"""
535-
def __init__(self, *shape):
536-
x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
537-
self._it = _nx.nditer(x, flags=['multi_index'], order='C')
538534

539-
def __iter__(self):
540-
return self
535+
def __init__(self, *args):
536+
if len(args) == 1 and isinstance(args[0], tuple):
537+
args = args[0]
538+
self.nd = len(args)
539+
self.ind = [0]*self.nd
540+
self.index = 0
541+
self.maxvals = args
542+
tot = 1
543+
for k in range(self.nd):
544+
tot *= args[k]
545+
self.total = tot
546+
547+
def _incrementone(self, axis):
548+
if (axis < 0): # base case
549+
return
550+
if (self.ind[axis] < self.maxvals[axis]-1):
551+
self.ind[axis] += 1
552+
else:
553+
self.ind[axis] = 0
554+
self._incrementone(axis-1)
541555

542556
def ndincr(self):
543557
"""
544558
Increment the multi-dimensional index by one.
545559
546-
This method is for backward compatibility only: do not use.
560+
`ndincr` takes care of the "wrapping around" of the axes.
561+
It is called by `ndindex.next` and not normally used directly.
562+
547563
"""
548-
self.next()
564+
self._incrementone(self.nd-1)
549565

550566
def next(self):
551567
"""
@@ -557,8 +573,17 @@ def next(self):
557573
Returns a tuple containing the indices of the current iteration.
558574
559575
"""
560-
self._it.next()
561-
return self._it.multi_index
576+
if (self.index >= self.total):
577+
raise StopIteration
578+
val = tuple(self.ind)
579+
self.index += 1
580+
self.ndincr()
581+
return val
582+
583+
def __iter__(self):
584+
return self
585+
586+
562587

563588

564589
# You can do all this with slice() plus a few special objects,

numpy/lib/tests/test_index_tricks.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from numpy import ( array, ones, r_, mgrid, unravel_index, zeros, where,
44
ndenumerate, fill_diagonal, diag_indices,
5-
diag_indices_from, s_, index_exp, ndindex )
5+
diag_indices_from, s_, index_exp )
66

77
class TestRavelUnravelIndex(TestCase):
88
def test_basic(self):
@@ -237,11 +237,5 @@ def test_diag_indices_from():
237237
assert_array_equal(c, np.arange(4))
238238

239239

240-
def test_ndindex():
241-
x = list(np.ndindex(1, 2, 3))
242-
expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
243-
assert_array_equal(x, expected)
244-
245-
246240
if __name__ == "__main__":
247241
run_module_suite()

0 commit comments

Comments
 (0)
0