8000 Merge pull request #10951 from mattip/nditer-close-fixes · numpy/numpy@f2888db · GitHub
[go: up one dir, main page]

Skip to content

Commit f2888db

Browse files
authored
Merge pull request #10951 from mattip/nditer-close-fixes
BUG: it.close() disallows access to iterator, fixes #10950
2 parents b5c1bcf + ac7d543 commit f2888db

File tree

4 files changed

+64
-40
lines changed

4 files changed

+64
-40
lines changed

doc/source/reference/arrays.nditer.rst

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -394,10 +394,10 @@ parameter support.
394394
.. admonition:: Example
395395

396396
>>> def square(a):
397-
... it = np.nditer([a, None])
398-
... for x, y in it:
399-
... y[...] = x*x
400-
... return it.operands[1]
397+
... with np.nditer([a, None]) as it:
398+
... for x, y in it:
399+
... y[...] = x*x
400+
... return it.operands[1]
401401
...
402402
>>> square([1,2,3])
403403
array([1, 4, 9])
@@ -490,17 +490,22 @@ Everything to do with the outer product is handled by the iterator setup.
490490
>>> b = np.arange(8).reshape(2,4)
491491
>>> it = np.nditer([a, b, None], flags=['external_loop'],
492492
... op_axes=[[0, -1, -1], [-1, 0, 1], None])
493-
>>> for x, y, z in it:
494-
... z[...] = x*y
493+
>>> with it:
494+
... for x, y, z in it:
495+
... z[...] = x*y
496+
... result = it.operands[2] # same as z
495497
...
496-
>>> it.operands[2]
498+
>>> result
497499
array([[[ 0, 0, 0, 0],
498500
[ 0, 0, 0, 0]],
499501
[[ 0, 1, 2, 3],
500502
[ 4, 5, 6, 7]],
501503
[[ 0, 2, 4, 6],
502504
[ 8, 10, 12, 14]]])
503505

506+
Note that once the iterator is closed we can not access :func:`operands <nditer.operands>`
507+
and must use a reference created inside the context manager.
508+
504509
Reduction Iteration
505510
-------------------
506511

@@ -540,8 +545,9 @@ sums along the last axis of `a`.
540545
... it.operands[1][...] = 0
541546
... for x, y in it:
542547
... y[...] += x
548+
... result = it.operands[1]
543549
...
544-
... it.operands[1]
550+
>>> result
545551
array([[ 6, 22, 38],
546552
[54, 70, 86]])
547553
>>> np.sum(a, axis=2)
@@ -575,8 +581,9 @@ buffering.
575581
... it.reset()
576582
... for x, y in it:
577583
... y[...] += x
584+
... result = it.operands[1]
578585
...
579-
... it.operands[1]
586+
>>> result
580587
array([[ 6, 22, 38],
581588
[54, 70, 86]])
582589

numpy/add_newdocs.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,7 @@
257257
dtypes : tuple of dtype(s)
258258
The data types of the values provided in `value`. This may be
259259
different from the operand data types if buffering is enabled.
260+
Valid only before the iterator is closed.
260261
finished : bool
261262
Whether the iteration over the operands is finished or not.
262263
has_delayed_bufalloc : bool
@@ -282,7 +283,8 @@
282283
Size of the iterator.
283284
itviews
284285
Structured view(s) of `operands` in memory, matching the reordered
285-
and optimized iterator access pattern.
286+
and optimized iterator access pattern. Valid only before the iterator
287+
is closed.
286288
multi_index
287289
When the "multi_index" flag was used, this property
288290
provides access to the index. Raises a ValueError if accessed
@@ -292,7 +294,8 @@
292294
nop : int
293295
The number of iterator operands.
294296
operands : tuple of operand(s)
295-
The array(s) to be iterated over.
297+
The array(s) to be iterated over. Valid only before the iterator is
298+
closed.
296299
shape : tuple of ints
297300
Shape tuple, the shape of the iterator.
298301
value
@@ -331,12 +334,12 @@ def iter_add(x, y, out=None):
331334
332335
it = np.nditer([x, y, out], [],
333336
[['readonly'], ['readonly'], ['writeonly','allocate']])
337+
with it:
338+
while not it.finished:
339+
addop(it[0], it[1], out=it[2])
340+
it.iternext()
334341
335-
while not it.finished:
336-
addop(it[0], it[1], out=it[2])
337-
it.iternext()
338-
339-
return it.operands[2]
342+
return it.operands[2]
340343
341344
Here is an example outer product function::
342345
@@ -351,7 +354,7 @@ def outer_it(x, y, out=None):
351354
with it:
352355
for (a, b, c) in it:
353356
mulop(a, b, out=c)
354-
return it.operands[2]
357+
return it.operands[2]
355358
356359
>>> a = np.arange(2)+1
357360
>>> b = np.arange(3)+1
@@ -374,7 +377,7 @@ def luf(lamdaexpr, *args, **kwargs):
374377
while not it.finished:
375378
it[0] = lamdaexpr(*it[1:])
376379
it.iternext()
377-
return it.operands[0]
380+
return it.operands[0]
378381
379382
>>> a = np.arange(5)
380383
>>> b = np.ones(5)
@@ -430,6 +433,13 @@ def luf(lamdaexpr, *args, **kwargs):
430433
431434
"""))
432435

436+
add_newdoc('numpy.core', 'nditer', ('operands',
437+
"""
438+
operands[`Slice`]
439+
440+
The array(s) to be iterated over. Valid only before the iterator is closed.
441+
"""))
442+
433443
add_newdoc('numpy.core', 'nditer', ('debug_print',
434444
"""
435445
debug_print()

numpy/core/src/multiarray/nditer_pywrap.c

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,14 @@
2020

2121
typedef struct NewNpyArrayIterObject_tag NewNpyArrayIterObject;
2222

23-
enum NPYITER_CONTEXT {CONTEXT_NOTENTERED, CONTEXT_INSIDE, CONTEXT_EXITED};
24-
2523
struct NewNpyArrayIterObject_tag {
2624
PyObject_HEAD
2725
/* The iterator */
2826
NpyIter *iter;
2927
/* Flag indicating iteration started/stopped */
3028
char started, finished;
31-
/* iter must used as a context manager if writebackifcopy semantics used */
32-
char managed;
29+
/* iter operands cannot be referenced if iter is closed */
30+
npy_bool is_closed;
3331
/* Child to update for nested iteration */
3432
NewNpyArrayIterObject *nested_child;
3533
/* Cached values from the iterator */
@@ -89,7 +87,7 @@ npyiter_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds)
8987
if (self != NULL) {
9088
self->iter = NULL;
9189
self->nested_child = NULL;
92-
self->managed = CONTEXT_NOTENTERED;
90+
self->is_closed = 0;
9391
}
9492

9593
return (PyObject *)self;
@@ -1419,7 +1417,7 @@ static PyObject *npyiter_value_get(NewNpyArrayIterObject *self)
14191417
ret = npyiter_seq_item(self, 0);
14201418
}
14211419
else {
1422-
if (self->managed == CONTEXT_EXITED) {
1420+
if (self->is_closed) {
14231421
PyErr_SetString(PyExc_ValueError,
14241422
"Iterator is closed");
14251423
return NULL;
@@ -1454,7 +1452,7 @@ static PyObject *npyiter_operands_get(NewNpyArrayIterObject *self)
14541452
"Iterator is invalid");
14551453
return NULL;
14561454
}
1457-
if (self->managed == CONTEXT_EXITED) {
1455+
if (self->is_closed) {
14581456
PyErr_SetString(PyExc_ValueError,
14591457
"Iterator is closed");
14601458
return NULL;
@@ -1489,7 +1487,7 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self)
14891487
return NULL;
14901488
}
14911489

1492-
if (self->managed == CONTEXT_EXITED) {
1490+
if (self->is_closed) {
14931491
PyErr_SetString(PyExc_ValueError,
14941492
"Iterator is closed");
14951493
return NULL;
@@ -1517,7 +1515,8 @@ static PyObject *npyiter_itviews_get(NewNpyArrayIterObject *self)
15171515
static PyObject *
15181516
npyiter_next(NewNpyArrayIterObject *self)
15191517
{
1520-
if (self->iter == NULL || self->iternext == NULL || self->finished) {
1518+
if (self->iter == NULL || self->iternext == NULL ||
1519+
self->finished || self->is_closed) {
15211520
return NULL;
15221521
}
15231522

@@ -1912,7 +1911,7 @@ static PyObject *npyiter_dtypes_get(NewNpyArrayIterObject *self)
19121911
return NULL;
19131912
}
19141913

1915-
if (self->managed == CONTEXT_EXITED) {
1914+
if (self->is_closed) {
19161915
PyErr_SetString(PyExc_ValueError,
19171916
"Iterator is closed");
19181917
return NULL;
@@ -2014,7 +2013,7 @@ npyiter_seq_item(NewNpyArrayIterObject *self, Py_ssize_t i)
20142013
return NULL;
20152014
}
20162015

2017-
if (self->managed == CONTEXT_EXITED) {
2016+
if (self->is_closed) {
20182017
PyErr_SetString(PyExc_ValueError,
20192018
"Iterator is closed");
20202019
return NULL;
@@ -2104,7 +2103,7 @@ npyiter_seq_slice(NewNpyArrayIterObject *self,
21042103
return NULL;
21052104
}
21062105

2107-
if (self->managed == CONTEXT_EXITED) {
2106+
if (self->is_closed) {
21082107
PyErr_SetString(PyExc_ValueError,
21092108
"Iterator is closed");
21102109
return NULL;
@@ -2170,7 +2169,7 @@ npyiter_seq_ass_item(NewNpyArrayIterObject *self, Py_ssize_t i, PyObject *v)
21702169
return -1;
21712170
}
21722171

2173-
if (self->managed == CONTEXT_EXITED) {
2172+
if (self->is_closed) {
21742173
PyErr_SetString(PyExc_ValueError,
21752174
"Iterator is closed");
21762175
return -1;
@@ -2250,7 +2249,7 @@ npyiter_seq_ass_slice(NewNpyArrayIterObject *self, Py_ssize_t ilow,
22502249
return -1;
22512250
}
22522251

2253-
if (self->managed == CONTEXT_EXITED) {
2252+
if (self->is_closed) {
22542253
PyErr_SetString(PyExc_ValueError,
22552254
"Iterator is closed");
22562255
return -1;
@@ -2307,7 +2306,7 @@ npyiter_subscript(NewNpyArrayIterObject *self, PyObject *op)
23072306
return NULL;
23082307
}
23092308

2310-
if (self->managed == CONTEXT_EXITED) {
2309+
if (self->is_closed) {
23112310
PyErr_SetString(PyExc_ValueError,
23122311
"Iterator is closed");
23132312
return NULL;
@@ -2362,7 +2361,7 @@ npyiter_ass_subscript(NewNpyArrayIterObject *self, PyObject *op,
23622361
return -1;
23632362
}
23642363

2365-
if (self->managed == CONTEXT_EXITED) {
2364+
if (self->is_closed) {
23662365
PyErr_SetString(PyExc_ValueError,
23672366
"Iterator is closed");
23682367
return -1;
@@ -2402,11 +2401,10 @@ npyiter_enter(NewNpyArrayIterObject *self)
24022401
PyErr_SetString(PyExc_RuntimeError, "operation on non-initialized iterator");
24032402
return NULL;
24042403
}
2405-
if (self->managed == CONTEXT_EXITED) {
2406-
PyErr_SetString(PyExc_ValueError, "cannot reuse iterator after exit");
2404+
if (self->is_closed) {
2405+
PyErr_SetString(PyExc_ValueError, "cannot reuse closed iterator");
24072406
return NULL;
24082407
}
2409-
self->managed = CONTEXT_INSIDE;
24102408
Py_INCREF(self);
24112409
return (PyObject *)self;
24122410
}
@@ -2420,6 +2418,7 @@ npyiter_close(NewNpyArrayIterObject *self)
24202418
Py_RETURN_NONE;
24212419
}
24222420
ret = NpyIter_Close(iter);
2421+
self->is_closed = 1;
24232422
if (ret < 0) {
24242423
return NULL;
24252424
}
@@ -2429,7 +2428,6 @@ npyiter_close(NewNpyArrayIterObject *self)
24292428
static PyObject *
24302429
npyiter_exit(NewNpyArrayIterObject *self, PyObject *args)
24312430
{
2432-
self->managed = CONTEXT_EXITED;
24332431
/* even if called via exception handling, writeback any data */
24342432
return npyiter_close(self);
24352433
}

numpy/core/tests/test_nditer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2847,7 +2847,7 @@ def test_writebacks():
28472847
enter = it.__enter__
28482848
assert_raises(ValueError, enter)
28492849

2850-
def test_close():
2850+
def test_close_equivalent():
28512851
''' using a context amanger and using nditer.close are equivalent
28522852
'''
28532853
def add_close(x, y, out=None):
@@ -2856,8 +2856,10 @@ def add_close(x, y, out=None):
28562856
[['readonly'], ['readonly'], ['writeonly','allocate']])
28572857
for (a, b, c) in it:
28582858
addop(a, b, out=c)
2859+
ret = it.operands[2]
28592860
it.close()
2860-
return it.operands[2]
2861+
return ret
2862+
28612863
def add_context(x, y, out=None):
28622864
addop = np.add
28632865
it = np.nditer([x, y, out], [],
@@ -2871,6 +2873,13 @@ def add_context(x, y, out=None):
28712873
z = add_context(range(5), range(5))
28722874
assert_equal(z, range(0, 10, 2))
28732875

2876+
def test_close_raises():
2877+
it = np.nditer(np.arange(3))
2878+
assert_equal (next(it), 0)
2879+
it.close()
2880+
assert_raises(StopIteration, next, it)
2881+
assert_raises(ValueError, getattr, it, 'operands')
2882+
28742883
def test_warn_noclose():
28752884
a = np.arange(6, dtype='f4')
28762885
au = a.byteswap().newbyteorder()

0 commit comments

Comments
 (0)
0