Merge pull request #19258 from seberg/maint-ufunc-refactor-iterator-loop · numpy/numpy@7d8fada · GitHub
[go: up one dir, main page]

Skip to content

Commit 7d8fada

Browse files
authored
Merge pull request #19258 from seberg/maint-ufunc-refactor-iterator-loop
MAINT: Refactor and simplify the main ufunc iterator loop code
2 parents f079a56 + 23600c6 commit 7d8fada

File tree

1 file changed

+61
-55
lines changed

1 file changed

+61
-55
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 61 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -1263,23 +1263,10 @@ iterator_loop(PyUFuncObject *ufunc,
12631263
void *innerloopdata,
12641264
npy_uint32 *op_flags)
12651265
{
1266-
npy_intp i, nin = ufunc->nin, nout = ufunc->nout;
1267-
npy_intp nop = nin + nout;
1268-
NpyIter *iter;
1269-
char *baseptrs[NPY_MAXARGS];
1270-
1271-
NpyIter_IterNextFunc *iternext;
1272-
char **dataptr;
1273-
npy_intp *stride;
1274-
npy_intp *count_ptr;
1275-
int needs_api;
1276-
1277-
PyArrayObject **op_it;
1278-
npy_uint32 iter_flags;
1279-
1280-
NPY_BEGIN_THREADS_DEF;
1266+
int nin = ufunc->nin, nout = ufunc->nout;
1267+
int nop = nin + nout;
12811268

1282-
iter_flags = ufunc->iter_flags |
1269+
npy_uint32 iter_flags = ufunc->iter_flags |
12831270
NPY_ITER_EXTERNAL_LOOP |
12841271
NPY_ITER_REFS_OK |
12851272
NPY_ITER_ZEROSIZE_OK |
@@ -1288,16 +1275,17 @@ iterator_loop(PyUFuncObject *ufunc,
12881275
NPY_ITER_DELAY_BUFALLOC |
12891276
NPY_ITER_COPY_IF_OVERLAP;
12901277

1291-
/* Call the __array_prepare__ functions for already existing output arrays.
1278+
/*
1279+
* Call the __array_prepare__ functions for already existing output arrays.
12921280
* Do this before creating the iterator, as the iterator may UPDATEIFCOPY
12931281
* some of them.
12941282
*/
1295-
for (i = 0; i < nout; ++i) {
1283+
for (int i = 0; i < nout; i++) {
12961284
if (op[nin+i] == NULL) {
12971285
continue;
12981286
}
12991287
if (prepare_ufunc_output(ufunc, &op[nin+i],
1300-
arr_prep[i], full_args, i) < 0) {
1288+
arr_prep[i], full_args, i) < 0) {
13011289
return -1;
13021290
}
13031291
}
@@ -1307,7 +1295,7 @@ iterator_loop(PyUFuncObject *ufunc,
13071295
* were already checked, we use the casting rule 'unsafe' which
13081296
* is faster to calculate.
13091297
*/
1310-
iter = NpyIter_AdvancedNew(nop, op,
1298+
NpyIter *iter = NpyIter_AdvancedNew(nop, op,
13111299
iter_flags,
13121300
order, NPY_UNSAFE_CASTING,
13131301
op_flags, dtype,
@@ -1316,16 +1304,20 @@ iterator_loop(PyUFuncObject *ufunc,
13161304
return -1;
13171305
}
13181306

1319-
/* Copy any allocated outputs */
1320-
op_it = NpyIter_GetOperandArray(iter);
1321-
for (i = 0; i < nout; ++i) {
1322-
if (op[nin+i] == NULL) {
1323-
op[nin+i] = op_it[nin+i];
1324-
Py_INCREF(op[nin+i]);
1307+
NPY_UF_DBG_PRINT("Made iterator\n");
1308+
1309+
/* Call the __array_prepare__ functions for newly allocated arrays */
1310+
PyArrayObject **op_it = NpyIter_GetOperandArray(iter);
1311+
char *baseptrs[NPY_MAXARGS];
1312+
1313+
for (int i = 0; i < nout; ++i) {
1314+
if (op[nin + i] == NULL) {
1315+
op[nin + i] = op_it[nin + i];
1316+
Py_INCREF(op[nin + i]);
13251317

13261318
/* Call the __array_prepare__ functions for the new array */
1327-
if (prepare_ufunc_output(ufunc, &op[nin+i],
1328-
arr_prep[i], full_args, i) < 0) {
1319+
if (prepare_ufunc_output(ufunc,
1320+
&op[nin + i], arr_prep[i], full_args, i) < 0) {
13291321
NpyIter_Deallocate(iter);
13301322
return -1;
13311323
}
@@ -1340,45 +1332,59 @@ iterator_loop(PyUFuncObject *ufunc,
13401332
* with other operands --- the op[nin+i] array passed to it is newly
13411333
* allocated and doesn't have any overlap.
13421334
*/
1343-
baseptrs[nin+i] = PyArray_BYTES(op[nin+i]);
1335+
baseptrs[nin + i] = PyArray_BYTES(op[nin + i]);
13441336
}
13451337
else {
1346-
baseptrs[nin+i] = PyArray_BYTES(op_it[nin+i]);
1338+
baseptrs[nin + i] = PyArray_BYTES(op_it[nin + i]);
13471339
}
13481340
}
1349-
13501341
/* Only do the loop if the iteration size is non-zero */
1351-
if (NpyIter_GetIterSize(iter) != 0) {
1352-
/* Reset the iterator with the base pointers from possible __array_prepare__ */
1353-
for (i = 0; i < nin; ++i) {
1354-
baseptrs[i] = PyArray_BYTES(op_it[i]);
1355-
}
1356-
if (NpyIter_ResetBasePointers(iter, baseptrs, NULL) != NPY_SUCCEED) {
1357-
NpyIter_Deallocate(iter);
1342+
npy_intp full_size = NpyIter_GetIterSize(iter);
1343+
if (full_size == 0) {
1344+
if (!NpyIter_Deallocate(iter)) {
13581345
return -1;
13591346
}
1347+
return 0;
1348+
}
13601349

1361-
/* Get the variables needed for the loop */
1362-
iternext = NpyIter_GetIterNext(iter, NULL);
1363-
if (iternext == NULL) {
1364-
NpyIter_Deallocate(iter);
1365-
return -1;
1366-
}
1367-
dataptr = NpyIter_GetDataPtrArray(iter);
1368-
stride = NpyIter_GetInnerStrideArray(iter);
1369-
count_ptr = NpyIter_GetInnerLoopSizePtr(iter);
1370-
needs_api = NpyIter_IterationNeedsAPI(iter);
1350+
/*
1351+
* Reset the iterator with the base pointers possibly modified by
1352+
* `__array_prepare__`.
1353+
*/
1354+
for (int i = 0; i < nin; i++) {
1355+
baseptrs[i] = PyArray_BYTES(op_it[i]);
1356+
}
1357+
if (NpyIter_ResetBasePointers(iter, baseptrs, NULL) != NPY_SUCCEED) {
1358+
NpyIter_Deallocate(iter);
1359+
return -1;
1360+
}
13711361

1372-
NPY_BEGIN_THREADS_NDITER(iter);
1362+
/* Get the variables needed for the loop */
1363+
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);
1364+
if (iternext == NULL) {
1365+
NpyIter_Deallocate(iter);
1366+
return -1;
1367+
}
1368+
char **dataptr = NpyIter_GetDataPtrArray(iter);
1369+
npy_intp *strides = NpyIter_GetInnerStrideArray(iter);
1370+
npy_intp *countptr = NpyIter_GetInnerLoopSizePtr(iter);
1371+
int needs_api = NpyIter_IterationNeedsAPI(iter);
13731372

1374-
/* Execute the loop */
1375-
do {
1376-
NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)*count_ptr);
1377-
innerloop(dataptr, count_ptr, stride, innerloopdata);
1378-
} while (!(needs_api && PyErr_Occurred()) && iternext(iter));
1373+
NPY_BEGIN_THREADS_DEF;
13791374

1380-
NPY_END_THREADS;
1375+
if (!needs_api) {
1376+
NPY_BEGIN_THREADS_THRESHOLDED(full_size);
13811377
}
1378+
1379+
NPY_UF_DBG_PRINT("Actual inner loop:\n");
1380+
/* Execute the loop */
1381+
do {
1382+
NPY_UF_DBG_PRINT1("iterator loop count %d\n", (int)*count_ptr);
1383+
innerloop(dataptr, countptr, strides, innerloopdata);
1384+
} while (!(needs_api && PyErr_Occurred()) && iternext(iter));
1385+
1386+
NPY_END_THREADS;
1387+
13821388
/*
13831389
* Currently `innerloop` may leave an error set, in this case
13841390
* NpyIter_Deallocate will always return an error as well.

0 commit comments

Comments
 (0)
0