8000 ENH: Allow np.nditer to support scalar op_axes · numpy/numpy@acce195 · GitHub
[go: up one dir, main page]

Skip to content

Commit acce195

Browse files
committed
ENH: Allow np.nditer to support scalar op_axes
Also uses oa_ndim == -1 to signal no op_axes were given. This is slightly cleaner inside pywrap itself and is a cleaner signal for the iterator.
1 parent 2e8fcc0 commit acce195

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

numpy/core/src/multiarray/nditer_pywrap.c

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ NpyIter_GlobalFlagsConverter(PyObject *flags_in, npy_uint32 *flags)
9595
npy_uint32 flag;
9696

9797
if (flags_in == NULL || flags_in == Py_None) {
98-
*flags = 0;
9998
return 1;
10099
}
101100

@@ -526,7 +525,7 @@ npyiter_convert_op_axes(PyObject *op_axes_in, npy_intp nop,
526525
return 0;
527526
}
528527

529-
*oa_ndim = 0;
528+
*oa_ndim = -1;
530529

531530
/* Copy the tuples into op_axes */
532531
for (iop = 0; iop < nop; ++iop) {
@@ -545,13 +544,8 @@ npyiter_convert_op_axes(PyObject *op_axes_in, npy_intp nop,
545544
Py_DECREF(a);
546545
return 0;
547546
}
548-
if (*oa_ndim == 0) {
547+
if (*oa_ndim == -1) {
549548
*oa_ndim = PySequence_Size(a);
550-
if (*oa_ndim == 0) {
551-
PyErr_SetString(PyExc_ValueError,
552-
"op_axes must have at least one dimension");
553-
return 0;
554-
}
555549
if (*oa_ndim > NPY_MAXDIMS) {
556550
PyErr_SetString(PyExc_ValueError,
557551
"Too many dimensions in op_axes");
@@ -575,7 +569,7 @@ npyiter_convert_op_axes(PyObject *op_axes_in, npy_intp nop,
575569
op_axes[iop][idim] = -1;
576570
}
577571
else {
578-
op_axes[iop][idim] = PyInt_AsLong(v);
572+
op_axes[iop][idim] = PyArray_PyIntAsInt(v);
579573
if (op_axes[iop][idim]==-1 &&
580574
PyErr_Occurred()) {
581575
Py_DECREF(a);
@@ -589,7 +583,7 @@ npyiter_convert_op_axes(PyObject *op_axes_in, npy_intp nop,
589583
}
590584
}
591585

592-
if (*oa_ndim == 0) {
586+
if (*oa_ndim == -1) {
593587
PyErr_SetString(PyExc_ValueError,
594588
"If op_axes is provided, at least one list of axes "
595589
"must be contained within it");
@@ -726,7 +720,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
726720
NPY_CASTING casting = NPY_SAFE_CASTING;
727721
npy_uint32 op_flags[NPY_MAXARGS];
728722
PyArray_Descr *op_request_dtypes[NPY_MAXARGS];
729-
int oa_ndim = 0;
723+
int oa_ndim = -1;
730724
int op_axes_arrays[NPY_MAXARGS][NPY_MAXDIMS];
731725
int *op_axes[NPY_MAXARGS];
732726
PyArray_Dims itershape = {NULL, 0};
@@ -784,7 +778,7 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
784778
}
785779

786780
if (itershape.len > 0) {
787-
if (oa_ndim == 0) {
781+
if (oa_ndim == -1) {
788782
oa_ndim = itershape.len;
789783
memset(op_axes, 0, sizeof(op_axes[0]) * nop);
790784
}
@@ -800,10 +794,9 @@ npyiter_init(NewNpyArrayIterObject *self, PyObject *args, PyObject *kwds)
800794
itershape.ptr = NULL;
801795
}
802796

803-
804797
self->iter = NpyIter_AdvancedNew(nop, op, flags, order, casting, op_flags,
805798
op_request_dtypes,
806-
oa_ndim, oa_ndim > 0 ? op_axes : NULL,
799+
oa_ndim, oa_ndim >= 0 ? op_axes : NULL,
807800
itershape.ptr,
808801
buffersize);
809802

@@ -860,7 +853,7 @@ NpyIter_NestedIters(PyObject *NPY_UNUSED(self),
860853

861854
int iop, nop = 0, inest, nnest = 0;
862855
PyArrayObject *op[NPY_MAXARGS];
863-
npy_uint32 flags = 0, flags_inner = 0;
856+
npy_uint32 flags = 0, flags_inner;
864857
NPY_ORDER order = NPY_KEEPORDER;
865858
NPY_CASTING casting = NPY_SAFE_CASTING;
866859
npy_uint32 op_flags[NPY_MAXARGS], op_flags_inner[NPY_MAXARGS];

0 commit comments

Comments
 (0)
0