8000 ENH: ufunc: Add a mask dtype parameter to the masked ufunc loop selector · numpy/numpy@dbabb8e · GitHub
[go: up one dir, main page]

Skip to content

Commit dbabb8e

Browse files
committed
ENH: ufunc: Add a mask dtype parameter to the masked ufunc loop selector
This is to allow for future expansion to multi-NA and struct-NA.
1 parent 9b4ff64 commit dbabb8e

File tree

4 files changed

+15
-1
lines changed

4 files changed

+15
-1
lines changed

numpy/core/include/numpy/ufuncobject.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ typedef int (PyUFunc_InnerLoopSelectionFunc)(
119119
typedef int (PyUFunc_MaskedInnerLoopSelectionFunc)(
120120
struct _tagPyUFuncObject *ufunc,
121121
PyArray_Descr **dtypes,
122+
PyArray_Descr *mask_dtype,
122123
npy_intp *fixed_strides,
123124
npy_intp fixed_mask_stride,
124125
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,

numpy/core/src/umath/ufunc_object.c

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,7 @@ execute_ufunc_masked_loop(PyUFuncObject *ufunc,
16401640
PyUFunc_MaskedStridedInnerLoopFunc *innerloop;
16411641
NpyAuxData *innerloopdata;
16421642
npy_intp fixed_strides[2*NPY_MAXARGS];
1643+
PyArray_Descr **iter_dtypes;
16431644

16441645
/* Validate that the prepare_ufunc_output didn't mess with pointers */
16451646
for (i = nin; i < nop; ++i) {
@@ -1657,7 +1658,10 @@ execute_ufunc_masked_loop(PyUFuncObject *ufunc,
16571658
* based on the fixed strides.
16581659
*/
16591660
NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
1661+
iter_dtypes = NpyIter_GetDescrArray(iter);
16601662
if (ufunc->masked_inner_loop_selector(ufunc, dtypes,
1663+
wheremask != NULL ? iter_dtypes[nop]
1664+
: iter_dtypes[nop + nin],
16611665
fixed_strides,
16621666
wheremask != NULL ? fixed_strides[nop]
16631667
: fixed_strides[nop + nin],
@@ -2686,7 +2690,7 @@ masked_reduce_loop(NpyIter *iter, char **dataptrs, npy_intp *strides,
26862690
dtypes[0] = iter_dtypes[0];
26872691
dtypes[1] = iter_dtypes[1];
26882692
dtypes[2] = iter_dtypes[0];
2689-
if (ufunc->masked_inner_loop_selector(ufunc, dtypes,
2693+
if (ufunc->masked_inner_loop_selector(ufunc, dtypes, iter_dtypes[2],
26902694
fixed_strides, fixed_mask_stride,
26912695
&innerloop, &innerloopdata, &needs_api) < 0) {
26922696
return -1;

numpy/core/src/umath/ufunc_type_resolution.c

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,7 @@ unmasked_ufunc_loop_as_masked(
13921392
NPY_NO_EXPORT int
13931393
PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
13941394
PyArray_Descr **dtypes,
1395+
PyArray_Descr *mask_dtype,
13951396
npy_intp *NPY_UNUSED(fixed_strides),
13961397
npy_intp NPY_UNUSED(fixed_mask_stride),
13971398
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,
@@ -1409,6 +1410,13 @@ PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
14091410
return -1;
14101411
}
14111412

1413+
if (mask_dtype->type_num != NPY_BOOL) {
1414+
PyErr_SetString(PyExc_ValueError,
1415+
"only boolean masks are supported in ufunc inner loops "
1416+
"presently");
1417+
return -1;
1418+
}
1419+
14121420
/* Create a new NpyAuxData object for the masker data */
14131421
data = (_ufunc_masker_data *)PyArray_malloc(sizeof(_ufunc_masker_data));
14141422
if (data == NULL) {

numpy/core/src/umath/ufunc_type_resolution.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ PyUFunc_DefaultLegacyInnerLoopSelector(PyUFuncObject *ufunc,
102102
NPY_NO_EXPORT int
103103
PyUFunc_DefaultMaskedInnerLoopSelector(PyUFuncObject *ufunc,
104104
PyArray_Descr **dtypes,
105+
PyArray_Descr *mask_dtypes,
105106
npy_intp *NPY_UNUSED(fixed_strides),
106107
npy_intp NPY_UNUSED(fixed_mask_stride),
107108
PyUFunc_MaskedStridedInnerLoopFunc **out_innerloop,

0 commit comments

Comments
 (0)
0