Merge pull request #11580 from mattip/ufunc-flag-refactor · numpy/numpy@97df928 · GitHub
[go: up one dir, main page]

Skip to content

Commit 97df928

Browse files
authored
Merge pull request #11580 from mattip/ufunc-flag-refactor
MAINT: refactor ufunc iter operand flags handling
2 parents bc81177 + 9afa2f7 commit 97df928

File tree

1 file changed

+113
-120
lines changed

1 file changed

+113
-120
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 113 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,49 @@ _find_array_prepare(ufunc_full_args args,
308308
return;
309309
}
310310

311+
/*
* Default per-operand iterator flags for ufunc inputs: the bitwise OR of
* NPY_ITER_READONLY, NPY_ITER_ALIGNED and
* NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE.
* NOTE(review): the expansion is not wrapped in parentheses, so it is only
* safe when combined with `|`; parenthesize at the use site before applying
* `&` or `~` to it.
*/
#define NPY_UFUNC_DEFAULT_INPUT_FLAGS \
312+
NPY_ITER_READONLY | \
313+
NPY_ITER_ALIGNED | \
314+
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE
315+
316+
/*
* Default per-operand iterator flags for ufunc outputs: the bitwise OR of
* NPY_ITER_ALIGNED, NPY_ITER_ALLOCATE, NPY_ITER_NO_BROADCAST,
* NPY_ITER_NO_SUBTYPE and NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE.
* No access-mode flag (READONLY / WRITEONLY / READWRITE) is included;
* each user of this macro ORs in the one it needs.
* NPY_ITER_NO_SUBTYPE is already part of this set, so callers do not need
* to add it again.
* NOTE(review): the expansion is not wrapped in parentheses, so it is only
* safe when combined with `|`; parenthesize at the use site before applying
* `&` or `~` to it.
*/
#define NPY_UFUNC_DEFAULT_OUTPUT_FLAGS \
317+
NPY_ITER_ALIGNED | \
318+
NPY_ITER_ALLOCATE | \
319+
NPY_ITER_NO_BROADCAST | \
320+
NPY_ITER_NO_SUBTYPE | \
321+
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE
322+
/*
323+
* Fill op_flags[0 .. nin+nout-1] with the per-operand iterator flags for
* one ufunc call, combining ufunc->op_flags with the supplied defaults.
324+
* Inputs (i < ufunc->nin): op_flags[i] = ufunc->op_flags[i] | op_in_flags;
325+
* if the merged value has NPY_ITER_READWRITE or NPY_ITER_WRITEONLY set,
326+
* NPY_ITER_READONLY is cleared again.
327+
* Outputs (nin <= i < nin + nout): op_flags[i] is ufunc->op_flags[i] when
328+
* that value is non-zero, otherwise op_out_flags; the two are never merged.
329+
* The input flag behavior preserves backward compatibility, while the
330+
* output flag behavior is the "correct" one for maximum flexibility.
*
* ufunc        - supplies nin, nout and the per-operand op_flags array
* op_in_flags  - flags OR-ed into every input operand's flags
* op_out_flags - flags used for an output whose ufunc->op_flags entry is 0
* op_flags     - destination array, written only; it must have room for
*                nin + nout entries, and entries at index nin + nout and
*                beyond are left untouched
331+
*/
332+
NPY_NO_EXPORT void
333+
_ufunc_setup_flags(PyUFuncObject *ufunc, npy_uint32 op_in_flags,
334+
npy_uint32 op_out_flags, npy_uint32 *op_flags)
335+
{
336+
int nin = ufunc->nin;
337+
int nout = ufunc->nout;
338+
int nop = nin + nout, i;
339+
/* Inputs: merge the ufunc's own flags with the caller's defaults */
340+
for (i = 0; i < nin; ++i) {
341+
op_flags[i] = ufunc->op_flags[i] | op_in_flags;
342+
/*
343+
* If the READWRITE or WRITEONLY flag has been set for this operand,
344+
* then clear the default READONLY flag
345+
*/
346+
if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) {
347+
op_flags[i] &= ~NPY_ITER_READONLY;
348+
}
349+
}
350+
/* Outputs: a non-zero ufunc flag word replaces op_out_flags entirely */
for (i = nin; i < nop; ++i) {
351+
op_flags[i] = ufunc->op_flags[i] ? ufunc->op_flags[i] : op_out_flags;
352+
}
353+
}
311354

312355
/*
313356
* This function analyzes the input arguments
@@ -1394,11 +1437,11 @@ iterator_loop(PyUFuncObject *ufunc,
13941437
PyObject **arr_prep,
13951438
ufunc_full_args full_args,
13961439
PyUFuncGenericFunction innerloop,
1397-
void *innerloopdata)
1440+
void *innerloopdata,
1441+
npy_uint32 *op_flags)
13981442
{
13991443
npy_intp i, nin = ufunc->nin, nout = ufunc->nout;
14001444
npy_intp nop = nin + nout;
1401-
npy_uint32 op_flags[NPY_MAXARGS];
14021445
NpyIter *iter;
14031446
char *baseptrs[NPY_MAXARGS];
14041447

@@ -1412,29 +1455,6 @@ iterator_loop(PyUFuncObject *ufunc,
14121455

14131456
NPY_BEGIN_THREADS_DEF;
14141457

1415-
/* Set up the flags */
1416-
for (i = 0; i < nin; ++i) {
1417-
op_flags[i] = NPY_ITER_READONLY |
1418-
NPY_ITER_ALIGNED |
1419-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
1420-
/*
1421-
* If READWRITE flag has been set for this operand,
1422-
* then clear default READONLY flag
1423-
*/
1424-
op_flags[i] |= ufunc->op_flags[i];
1425-
if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) {
1426-
op_flags[i] &= ~NPY_ITER_READONLY;
1427-
}
1428-
}
1429-
for (i = nin; i < nop; ++i) {
1430-
op_flags[i] = NPY_ITER_WRITEONLY |
1431-
NPY_ITER_ALIGNED |
1432-
NPY_ITER_ALLOCATE |
1433-
NPY_ITER_NO_BROADCAST |
1434-
NPY_ITER_NO_SUBTYPE |
1435-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
1436-
}
1437-
14381458
iter_flags = ufunc->iter_flags |
14391459
NPY_ITER_EXTERNAL_LOOP |
14401460
NPY_ITER_REFS_OK |
@@ -1538,15 +1558,15 @@ iterator_loop(PyUFuncObject *ufunc,
15381558
}
15391559

15401560
/*
1561+
* ufunc - the ufunc to call
15411562
* trivial_loop_ok - 1 if no alignment, data conversion, etc required
1542-
* nin - number of inputs
1543-
* nout - number of outputs
1544-
* op - the operands (nin + nout of them)
1563+
* op - the operands (ufunc->nin + ufunc->nout of them)
1564+
* dtypes - the dtype of each operand
15451565
* order - the loop execution order/output memory order
15461566
* buffersize - how big of a buffer to use
15471567
* arr_prep - the __array_prepare__ functions for the outputs
1548-
* innerloop - the inner loop function
1549-
* innerloopdata - data to pass to the inner loop
1568+
* full_args - the original input, output PyObject *
1569+
* op_flags - per-operand flags, a combination of NPY_ITER_* constants
15501570
*/
15511571
static int
15521572
execute_legacy_ufunc_loop(PyUFuncObject *ufunc,
@@ -1556,7 +1576,8 @@ execute_legacy_ufunc_loop(PyUFuncObject *ufunc,
15561576
NPY_ORDER order,
15571577
npy_intp buffersize,
15581578
PyObject **arr_prep,
1559-
ufunc_full_args full_args)
1579+
ufunc_full_args full_args,
1580+
npy_uint32 *op_flags)
15601581
{
15611582
npy_intp nin = ufunc->nin, nout = ufunc->nout;
15621583
PyUFuncGenericFunction innerloop;
@@ -1691,7 +1712,7 @@ execute_legacy_ufunc_loop(PyUFuncObject *ufunc,
16911712
NPY_UF_DBG_PRINT("iterator loop\n");
16921713
if (iterator_loop(ufunc, op, dtypes, order,
16931714
buffersize, arr_prep, full_args,
1694-
innerloop, innerloopdata) < 0) {
1715+
innerloop, innerloopdata, op_flags) < 0) {
16951716
return -1;
16961717
}
16971718

@@ -1717,14 +1738,13 @@ execute_fancy_ufunc_loop(PyUFuncObject *ufunc,
17171738
NPY_ORDER order,
17181739
npy_intp buffersize,
17191740
PyObject **arr_prep,
1720-
ufunc_full_args full_args)
1741+
ufunc_full_args full_args,
1742+
npy_uint32 *op_flags)
17211743
{
17221744
int i, nin = ufunc->nin, nout = ufunc->nout;
17231745
int nop = nin + nout;
1724-
npy_uint32 op_flags[NPY_MAXARGS];
17251746
NpyIter *iter;
17261747
int needs_api;
1727-
npy_intp default_op_in_flags = 0, default_op_out_flags = 0;
17281748

17291749
NpyIter_IterNextFunc *iternext;
17301750
char **dataptr;
@@ -1734,48 +1754,10 @@ execute_fancy_ufunc_loop(PyUFuncObject *ufunc,
17341754
PyArrayObject **op_it;
17351755
npy_uint32 iter_flags;
17361756

1737-
if (wheremask != NULL) {
1738-
if (nop + 1 > NPY_MAXARGS) {
1739-
PyErr_SetString(PyExc_ValueError,
1740-
"Too many operands when including where= parameter");
1741-
return -1;
1742-
}
1743-
op[nop] = wheremask;
1744-
dtypes[nop] = NULL;
1745-
default_op_out_flags |= NPY_ITER_WRITEMASKED;
1746-
}
1747-
1748-
/* Set up the flags */
1749-
for (i = 0; i < nin; ++i) {
1750-
op_flags[i] = default_op_in_flags |
1751-
NPY_ITER_READONLY |
1752-
NPY_ITER_ALIGNED |
1753-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
1754-
/*
1755-
* If READWRITE flag has been set for this operand,
1756-
* then clear default READONLY flag
1757-
*/
1758-
op_flags[i] |= ufunc->op_flags[i];
1759-
if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) {
1760-
op_flags[i] &= ~NPY_ITER_READONLY;
1761-
}
1762-
}
17631757
for (i = nin; i < nop; ++i) {
1764-
/*
1765-
* We don't write to all elements, and the iterator may make
1766-
* UPDATEIFCOPY temporary copies. The output arrays (unless they are
1767-
* allocated by the iterator itself) must be considered READWRITE by the
1768-
* iterator, so that the elements we don't write to are copied to the
1769-
* possible temporary array.
1770-
*/
1771-
op_flags[i] = default_op_out_flags |
1772-
(op[i] != NULL ? NPY_ITER_READWRITE : NPY_ITER_WRITEONLY) |
1773-
NPY_ITER_ALIGNED |
1774-
NPY_ITER_ALLOCATE |
1775-
NPY_ITER_NO_BROADCAST |
1776-
NPY_ITER_NO_SUBTYPE |
1777-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
1758+
op_flags[i] |= (op[i] != NULL ? NPY_ITER_READWRITE : NPY_ITER_WRITEONLY);
17781759
}
1760+
17791761
if (wheremask != NULL) {
17801762
op_flags[nop] = NPY_ITER_READONLY | NPY_ITER_ARRAYMASK;
17811763
}
@@ -2785,6 +2767,18 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
27852767
if (retval < 0) {
27862768
goto fail;
27872769
}
2770+
/*
2771+
* We don't write to all elements, and the iterator may make
2772+
* UPDATEIFCOPY temporary copies. The output arrays (unless they are
2773+
* allocated by the iterator itself) must be considered READWRITE by the
2774+
* iterator, so that the elements we don't write to are copied to the
2775+
* possible temporary array.
2776+
*/
2777+
_ufunc_setup_flags(ufunc, NPY_ITER_COPY | NPY_UFUNC_DEFAULT_INPUT_FLAGS,
2778+
NPY_ITER_UPDATEIFCOPY |
2779+
NPY_ITER_READWRITE |
2780+
NPY_UFUNC_DEFAULT_OUTPUT_FLAGS,
2781+
op_flags);
27882782
/* For the generalized ufunc, we get the loop right away too */
27892783
retval = ufunc->legacy_inner_loop_selector(ufunc, dtypes,
27902784
&innerloop, &innerloopdata, &needs_api);
@@ -2827,28 +2821,6 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
28272821
* Set up the iterator per-op flags. For generalized ufuncs, we
28282822
* can't do buffering, so must COPY or UPDATEIFCOPY.
28292823
*/
2830-
for (i = 0; i < nin; ++i) {
2831-
op_flags[i] = NPY_ITER_READONLY |
2832-
NPY_ITER_COPY |
2833-
NPY_ITER_ALIGNED |
2834-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
2835-
/*
2836-
* If READWRITE flag has been set for this operand,
2837-
* then clear default READONLY flag
2838-
*/
2839-
op_flags[i] |= ufunc->op_flags[i];
2840-
if (op_flags[i] & (NPY_ITER_READWRITE | NPY_ITER_WRITEONLY)) {
2841-
op_flags[i] &= ~NPY_ITER_READONLY;
2842-
}
2843-
}
2844-
for (i = nin; i < nop; ++i) {
2845-
op_flags[i] = NPY_ITER_READWRITE|
2846-
NPY_ITER_UPDATEIFCOPY|
2847-
NPY_ITER_ALIGNED|
2848-
NPY_ITER_ALLOCATE|
2849-
NPY_ITER_NO_BROADCAST|
2850-
NPY_ITER_OVERLAP_ASSUME_ELEMENTWISE;
2851-
}
28522824

28532825
iter_flags = ufunc->iter_flags |
28542826
NPY_ITER_MULTI_INDEX |
@@ -3097,7 +3069,8 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
30973069
int i, nop;
30983070
const char *ufunc_name;
30993071
int retval = -1, subok = 1;
3100-
int need_fancy = 0;
3072+
npy_uint32 op_flags[NPY_MAXARGS];
3073+
npy_intp default_op_out_flags;
31013074

31023075
PyArray_Descr *dtypes[NPY_MAXARGS];
31033076

@@ -3156,13 +3129,6 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
31563129
return retval;
31573130
}
31583131

3159-
/*
3160-
* Use the masked loop if a wheremask was specified.
3161-
*/
3162-
if (wheremask != NULL) {
3163-
need_fancy = 1;
3164-
}
3165-
31663132
/* Get the buffersize and errormask */
31673133
if (_get_bufsize_errmask(extobj, ufunc_name, &buffersize, &errormask) < 0) {
31683134
retval = -1;
@@ -3177,16 +3143,20 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
31773143
goto fail;
31783144
}
31793145

3180-
/* Only do the trivial loop check for the unmasked version. */
3181-
if (!need_fancy) {
3182-
/*
3183-
* This checks whether a trivial loop is ok, making copies of
3184-
* scalar and one dimensional operands if that will help.
3185-
*/
3186-
trivial_loop_ok = check_for_trivial_loop(ufunc, op, dtypes, buffersize);
3187-
if (trivial_loop_ok < 0) {
3188-
goto fail;
3189-
}
3146+
if (wheremask != NULL) {
3147+
/* Set up the flags. */
3148+
default_op_out_flags = NPY_ITER_NO_SUBTYPE |
3149+
NPY_ITER_WRITEMASKED |
3150+
NPY_UFUNC_DEFAULT_OUTPUT_FLAGS;
3151+
_ufunc_setup_flags(ufunc, NPY_UFUNC_DEFAULT_INPUT_FLAGS,
3152+
default_op_out_flags, op_flags);
3153+
}
3154+
else {
3155+
/* Set up the flags. */
3156+
default_op_out_flags = NPY_ITER_WRITEONLY |
3157+
NPY_UFUNC_DEFAULT_OUTPUT_FLAGS;
3158+
_ufunc_setup_flags(ufunc, NPY_UFUNC_DEFAULT_INPUT_FLAGS,
3159+
default_op_out_flags, op_flags);
31903160
}
31913161

31923162
#if NPY_UF_DBG_TRACING
@@ -3214,23 +3184,46 @@ PyUFunc_GenericFunction(PyUFuncObject *ufunc,
32143184
_find_array_prepare(full_args, arr_prep, nin, nout);
32153185
}
32163186

3217-
/* Start with the floating-point exception flags cleared */
3218-
npy_clear_floatstatus_barrier((char*)&ufunc);
32193187

32203188
/* Do the ufunc loop */
3221-
if (need_fancy) {
3189+
if (wheremask != NULL) {
32223190
NPY_UF_DBG_PRINT("Executing fancy inner loop\n");
32233191

3192+
if (nop + 1 > NPY_MAXARGS) {
3193+
PyErr_SetString(PyExc_ValueError,
3194+
"Too many operands when including where= parameter");
3195+
return -1;
3196+
}
3197+
op[nop] = wheremask;
3198+
dtypes[nop] = NULL;
3199+
3200+
/* Set up the flags */
3201+
3202+
npy_clear_floatstatus_barrier((char*)&ufunc);
32243203
retval = execute_fancy_ufunc_loop(ufunc, wheremask,
32253204
op, dtypes, order,
3226-
buffersize, arr_prep, full_args);
3205+
buffersize, arr_prep, full_args, op_flags);
32273206
}
32283207
else {
32293208
NPY_UF_DBG_PRINT("Executing legacy inner loop\n");
32303209

3210+
/*
3211+
* This checks whether a trivial loop is ok, making copies of
3212+
* scalar and one dimensional operands if that will help.
3213+
* Since it requires dtypes, it can only be called after
3214+
* ufunc->type_resolver
3215+
*/
3216+
trivial_loop_ok = check_for_trivial_loop(ufunc, op, dtypes, buffersize);
3217+
if (trivial_loop_ok < 0) {
3218+
goto fail;
3219+
}
3220+
3221+
/* check_for_trivial_loop on half-floats can overflow */
3222+
npy_clear_floatstatus_barrier((char*)&ufunc);
3223+
32313224
retval = execute_legacy_ufunc_loop(ufunc, trivial_loop_ok,
32323225
op, dtypes, order,
3233-
buffersize, arr_prep, full_args);
3226+
buffersize, arr_prep, full_args, op_flags);
32343227
}
32353228
if (retval < 0) {
32363229
goto fail;

0 commit comments

Comments
 (0)
0