18
18
import numpy .matrixlib as matrix
19
19
from function_base import diff
20
20
from numpy .lib ._compiled_base import ravel_multi_index , unravel_index
21
- from numpy .lib .stride_tricks import as_strided
22
21
makemat = matrix .matrix
23
22
24
23
def ix_ (* args ):
@@ -532,20 +531,37 @@ class ndindex(object):
532
531
(2, 1, 0)
533
532
534
533
"""
535
- def __init__ (self , * shape ):
536
- x = as_strided (_nx .zeros (1 ), shape = shape , strides = _nx .zeros_like (shape ))
537
- self ._it = _nx .nditer (x , flags = ['multi_index' ], order = 'C' )
538
534
539
- def __iter__ (self ):
540
- return self
535
+ def __init__ (self , * args ):
536
+ if len (args ) == 1 and isinstance (args [0 ], tuple ):
537
+ args = args [0 ]
538
+ self .nd = len (args )
539
+ self .ind = [0 ]* self .nd
540
+ self .index = 0
541
+ self .maxvals = args
542
+ tot = 1
543
+ for k in range (self .nd ):
544
+ tot *= args [k ]
545
+ self .total = tot
546
+
547
+ def _incrementone (self , axis ):
548
+ if (axis < 0 ): # base case
549
+ return
550
+ if (self .ind [axis ] < self .maxvals [axis ]- 1 ):
551
+ self .ind [axis ] += 1
552
+ else :
553
+ self .ind [axis ] = 0
554
+ self ._incrementone (axis - 1 )
541
555
542
556
def ndincr (self ):
543
557
"""
544
558
Increment the multi-dimensional index by one.
545
559
546
- This method is for backward compatibility only: do not use.
560
+ `ndincr` takes care of the "wrapping around" of the axes.
561
+ It is called by `ndindex.next` and not normally used directly.
562
+
547
563
"""
548
- self .next ( )
564
+ self ._incrementone ( self . nd - 1 )
549
565
550
566
def next (self ):
551
567
"""
@@ -557,8 +573,17 @@ def next(self):
557
573
Returns a tuple containing the indices of the current iteration.
558
574
559
575
"""
560
- self ._it .next ()
561
- return self ._it .multi_index
576
+ if (self .index >= self .total ):
577
+ raise StopIteration
578
+ val = tuple (self .ind )
579
+ self .index += 1
580
+ self .ndincr ()
581
+ return val
582
+
583
+ def __iter__ (self ):
584
+ return self
585
+
586
+
562
587
563
588
564
589
# You can do all this with slice() plus a few special objects,
0 commit comments