10000 BUG: Fix problems with ndindex and nditer · numpy/numpy@3043864 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3043864

Browse files
committed
BUG: Fix problems with ndindex and nditer
This fixes an issue with ndindex shape tuple recognition, and an issue in the nditer where scalar input did not produce an empty index tuple. To be able to fix nditer, an extra flag has been added: NPY_ITFLAG_SCALAR and a new function NpyIter_IsScalar has been added to the nditer API. Also a few tests have been added to make sure the ndindex behaves as intended.
1 parent d8988ab commit 3043864

File tree

7 files changed

+39
-19
lines changed

7 files changed

+39
-19
lines changed

numpy/core/code_generators/numpy_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@
328328
'PyDataMem_FREE': 289,
329329
'PyDataMem_RENEW': 290,
330330
'PyDataMem_SetEventHook': 291,
331+
'NpyIter_IsScalar': 292,
331332
}
332333

333334
ufunc_types_api = {

numpy/core/src/multiarray/nditer_api.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -844,6 +844,15 @@ NpyIter_IsGrowInner(NpyIter *iter)
844844
return (NIT_ITFLAGS(iter)&NPY_ITFLAG_GROWINNER) != 0;
845845
}
846846

847+
/*NUMPY_API
848+
* Whether the iterator output is scalar
849+
*/
850+
NPY_NO_EXPORT npy_bool
851+
NpyIter_IsScalar(NpyIter *iter)
852+
{
853+
return (NIT_ITFLAGS(iter)&NPY_ITFLAG_SCALAR) != 0;
854+
}
855+
847856
/*NUMPY_API
848857
* Gets the size of the buffer, or 0 if buffering is not enabled
849858
*/

numpy/core/src/multiarray/nditer_constr.c

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ static int
5454
npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags,
5555
char **op_dataptr,
5656
npy_uint32 *op_flags, int **op_axes,
57-
npy_intp *itershape,
58-
int output_scalars);
57+
npy_intp *itershape);
5958
static void
6059
npyiter_replace_axisdata(NpyIter *iter, int iop,
6160
PyArrayObject *op,
@@ -74,8 +73,7 @@ npyiter_find_best_axis_ordering(NpyIter *iter);
7473
static PyArray_Descr *
7574
npyiter_get_common_dtype(int nop, PyArrayObject **op,
7675
npyiter_opitflags *op_itflags, PyArray_Descr **op_dtype,
77-
PyArray_Descr **op_request_dtypes,
78-
int only_inputs, int output_scalars);
76+
PyArray_Descr **op_request_dtypes, int only_inputs);
7977
static PyArrayObject *
8078
npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype,
8179
npy_uint32 flags, npyiter_opitflags *op_itflags,
@@ -86,7 +84,7 @@ npyiter_allocate_arrays(NpyIter *iter,
8684
npy_uint32 flags,
8785
PyArray_Descr **op_dtype, PyTypeObject *subtype,
8886
npy_uint32 *op_flags, npyiter_opitflags *op_itflags,
89-
int **op_axes, int output_scalars);
87+
int **op_axes);
9088
static void
9189
npyiter_get_priority_subtype(int nop, PyArrayObject **op,
9290
npyiter_opitflags *op_itflags,
@@ -123,7 +121,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
123121
npy_int8 *perm;
124122
NpyIter_BufferData *bufferdata = NULL;
125123
int any_allocate = 0, any_missing_dtypes = 0,
126-
output_scalars = 0, need_subtype = 0;
124+
need_subtype = 0;
127125

128126
/* The subtype for automatically allocated outputs */
129127
double subtype_priority = NPY_PRIORITY;
@@ -177,7 +175,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
177175

178176
/* If 'ndim' is zero, any outputs should be scalars */
179177
if (ndim == 0) {
180-
output_scalars = 1;
178+
itflags |= NPY_ITFLAG_SCALAR;
181179
ndim = 1;
182180
}
183181

@@ -231,8 +229,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
231229

232230
/* Fill in the AXISDATA arrays and set the ITERSIZE field */
233231
if (!npyiter_fill_axisdata(iter, flags, op_itflags, op_dataptr,
234-
op_flags, op_axes, itershape,
235-
output_scalars)) {
232+
op_flags, op_axes, itershape)) {
236233
NpyIter_Deallocate(iter);
237234
return NULL;
238235
}
@@ -338,8 +335,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
338335
dtype = npyiter_get_common_dtype(nop, op,
339336
op_itflags, op_dtype,
340337
op_request_dtypes,
341-
only_inputs,
342-
output_scalars);
338+
only_inputs);
343339
if (dtype == NULL) {
344340
NpyIter_Deallocate(iter);
345341
return NULL;
@@ -389,7 +385,7 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
389385
* done now using a memory layout matching the iterator.
390386
*/
391387
if (!npyiter_allocate_arrays(iter, flags, op_dtype, subtype, op_flags,
392-
op_itflags, op_axes, output_scalars)) {
388+
op_itflags, op_axes)) {
393389
NpyIter_Deallocate(iter);
394390
return NULL;
395391
}
@@ -1437,8 +1433,7 @@ static int
14371433
npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags,
14381434
char **op_dataptr,
14391435
npy_uint32 *op_flags, int **op_axes,
1440-
npy_intp *itershape,
1441-
int output_scalars)
1436+
npy_intp *itershape)
14421437
{
14431438
npy_uint32 itflags = NIT_ITFLAGS(iter);
14441439
int idim, ndim = NIT_NDIM(iter);
@@ -1558,7 +1553,7 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf
15581553
ondim = PyArray_NDIM(op_cur);
15591554
if (bshape == 1) {
15601555
strides[iop] = 0;
1561-
if (idim >= ondim && !output_scalars &&
1556+
if (idim >= ondim && !(itflags & NPY_ITFLAG_SCALAR) &&
15621557
(op_flags[iop] & NPY_ITER_NO_BROADCAST)) {
15631558
goto operand_different_than_broadcast;
15641559
}
@@ -2393,8 +2388,7 @@ npyiter_find_best_axis_ordering(NpyIter *iter)
23932388
static PyArray_Descr *
23942389
npyiter_get_common_dtype(int nop, PyArrayObject **op,
23952390
npyiter_opitflags *op_itflags, PyArray_Descr **op_dtype,
2396-
PyArray_Descr **op_request_dtypes,
2397-
int only_inputs, int output_scalars)
2391+
PyArray_Descr **op_request_dtypes, int only_inputs)
23982392
{
23992393
int iop;
24002394
npy_intp narrs = 0, ndtypes = 0;
@@ -2693,7 +2687,7 @@ npyiter_allocate_arrays(NpyIter *iter,
26932687
npy_uint32 flags,
26942688
PyArray_Descr **op_dtype, PyTypeObject *subtype,
26952689
npy_uint32 *op_flags, npyiter_opitflags *op_itflags,
2696-
int **op_axes, int output_scalars)
2690+
int **op_axes)
26972691
{
26982692
npy_uint32 itflags = NIT_ITFLAGS(iter);
26992693
int idim, ndim = NIT_NDIM(iter);
@@ -2724,7 +2718,7 @@ npyiter_allocate_arrays(NpyIter *iter,
27242718
if (op[iop] == NULL) {
27252719
PyArrayObject *out;
27262720
PyTypeObject *op_subtype;
2727-
int ondim = output_scalars ? 0 : ndim;
2721+
int ondim = (itflags & NPY_ITFLAG_SCALAR) ? 0 : ndim;
27282722

27292723
/* Check whether the subtype was disabled */
27302724
op_subtype = (op_flags[iop] & NPY_ITER_NO_SUBTYPE) ?

numpy/core/src/multiarray/nditer_impl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@
101101
#define NPY_ITFLAG_REDUCE 0x1000
102102
/* Reduce iteration doesn't need to recalculate reduce loops next time */
103103
#define NPY_ITFLAG_REUSE_REDUCE_LOOPS 0x2000
104+
/* The iterator output is scalar */
105+
#define NPY_ITFLAG_SCALAR 0x4000
104106

105107
/* Internal iterator per-operand iterator flags */
106108

@@ -215,6 +217,7 @@ typedef npy_int16 npyiter_opitflags;
215217
&(iter)->iter_flexdata + NIT_RESETDATAPTR_OFFSET(itflags, ndim, nop)))
216218
#define NIT_BASEOFFSETS(iter) ((npy_intp *)( \
217219
&(iter)->iter_flexdata + NIT_BASEOFFSETS_OFFSET(itflags, ndim, nop)))
220+
218221
#define NIT_OPERANDS(iter) ((PyArrayObject **)( \
219222
&(iter)->iter_flexdata + NIT_OPERANDS_OFFSET(itflags, ndim, nop)))
220223
#define NIT_OPITFLAGS(iter) ((npyiter_opitflags *)( \

numpy/core/src/multiarray/nditer_pywrap.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,9 @@ static PyObject *npyiter_multi_index_get(NewNpyArrayIterObject *self)
15421542
}
15431543

15441544
if (self->get_multi_index != NULL) {
1545+
if (NpyIter_IsScalar(self->iter)) {
1546+
return PyTuple_New(0);
1547+
}
15451548
ndim = NpyIter_GetNDim(self->iter);
15461549
self->get_multi_index(self->iter, multi_index);
15471550
ret = PyTuple_New(ndim);
@@ -1968,6 +1971,7 @@ npyiter_seq_item(NewNpyArrayIterObject *self, Py_ssize_t i)
19681971
return NULL;
19691972
}
19701973

1974+
19711975
if (NpyIter_HasDelayedBufAlloc(self->iter)) {
19721976
PyErr_SetString(PyExc_ValueError,
19731977
"Iterator construction used delayed buffer allocation, "

numpy/lib/index_tricks.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,9 @@ class ndindex(object):
533533
534534
"""
535535
def __init__(self, *shape):
536+
# Accept shapes in the form f(x, y, ..) as well as f((x ,y, ..))
537+
if len(shape) == 1 and isinstance(shape[0], tuple):
538+
shape = shape[0]
536539
x = as_strided(_nx.zeros(1), shape=shape, strides=_nx.zeros_like(shape))
537540
self._it = _nx.nditer(x, flags=['multi_index'], order='C')
538541

numpy/lib/tests/test_index_tricks.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,12 @@ def test_ndindex():
241241
x = list(np.ndindex(1, 2, 3))
242242
expected = [ix for ix, e in np.ndenumerate(np.zeros((1, 2, 3)))]
243243
assert_array_equal(x, expected)
244+
# Packed as well as unpacked tuple are acceptable
245+
y = list(np.ndindex((1, 2, 3)))
246+
assert_array_equal(x, y)
247+
# Empty shape gives empty index
248+
z = list(np.ndindex(()))
249+
assert_equal(z, [()])
244250

245251

246252
if __name__ == "__main__":

0 commit comments

Comments
 (0)
0