8000 ENH: ufunc: Rewrite PyUFunc_Reduce to be more general and easier to a… · numpy/numpy@aed9925 · GitHub
[go: up one dir, main page]

Skip to content

Commit aed9925

Browse files
Mark Wiebecharris
authored andcommitted
ENH: ufunc: Rewrite PyUFunc_Reduce to be more general and easier to adapt to NA masks
This generalizes the 'axis' parameter to accept None or a list of axes on which to do the reduction.
1 parent 2f0bb5d commit aed9925

File tree

12 files changed

+600
-499
lines changed

12 files changed

+600
-499
lines changed

numpy/core/code_generators/generate_umath.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def english_upper(s):
246246
TD(O, f='PyNumber_Add'),
247247
),
248248
'subtract' :
249-
Ufunc(2, 1, Zero,
249+
Ufunc(2, 1, None, # Zero is only a unit to the right, not the left
250250
docstrings.get('numpy.core.umath.subtract'),
251251
'PyUFunc_SubtractionTypeResolution',
252252
TD(notimes_or_obj),
@@ -269,7 +269,7 @@ def english_upper(s):
269269
TD(O, f='PyNumber_Multiply'),
270270
),
271271
'divide' :
272-
Ufunc(2, 1, One,
272+
Ufunc(2, 1, None, # One is only a unit to the right, not the left
273273
docstrings.get('numpy.core.umath.divide'),
274274
'PyUFunc_DivisionTypeResolution',
275275
TD(intfltcmplx),
@@ -280,7 +280,7 @@ def english_upper(s):
280280
TD(O, f='PyNumber_Divide'),
281281
),
282282
'floor_divide' :
283-
Ufunc(2, 1, One,
283+
Ufunc(2, 1, None, # One is only a unit to the right, not the left
284284
docstrings.get('numpy.core.umath.floor_divide'),
285285
'PyUFunc_DivisionTypeResolution',
286286
TD(intfltcmplx),
@@ -290,7 +290,7 @@ def english_upper(s):
290290
TD(O, f='PyNumber_FloorDivide'),
291291
),
292292
'true_divide' :
293-
Ufunc(2, 1, One,
293+
Ufunc(2, 1, None, # One is only a unit to the right, not the left
294294
docstrings.get('numpy.core.umath.true_divide'),
295295
'PyUFunc_DivisionTypeResolution',
296296
TD('bBhH', out='d'),
@@ -309,7 +309,7 @@ def english_upper(s):
309309
TD(P, f='conjugate'),
310310
),
311311
'fmod' :
312-
Ufunc(2, 1, Zero,
312+
Ufunc(2, 1, None,
313313
docstrings.get('numpy.core.umath.fmod'),
314314
None,
315315
TD(ints),
@@ -338,7 +338,7 @@ def english_upper(s):
338338
TD(O, f='Py_get_one'),
339339
),
340340
'power' :
341-
Ufunc(2, 1, One,
341+
Ufunc(2, 1, None,
342342
docstrings.get('numpy.core.umath.power'),
343343
None,
344344
TD(ints),

numpy/core/code_generators/numpy_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,9 @@
327327
'NpyIter_GetFirstMaskNAOp': 288,
328328
'NpyIter_GetMaskNAIndexArray': 289,
329329
'PyArray_ReduceMaskArray': 290,
330+
'PyArray_CreateSortedStridePerm': 291,
331+
'PyArray_FillWithZero': 292,
332+
'PyArray_FillWithOne': 293,
330333
}
331334

332335
ufunc_types_api = {

numpy/core/include/numpy/halffloat.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ npy_half npy_half_nextafter(npy_half x, npy_half y);
5252
#define NPY_HALF_NINF (0xfc00u)
5353
#define NPY_HALF_NAN (0x7e00u)
5454

55+
#define NPY_MAX_HALF (0x7bffu)
56+
5557
/*
5658
* Bit-level conversions
5759
*/

numpy/core/include/numpy/ndarraytypes.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1759,6 +1759,14 @@ struct NpyAuxData_tag {
17591759
#define NPY_AUXDATA_CLONE(auxdata) \
17601760
((auxdata)->clone(auxdata))
17611761

1762+
/************************************************************
1763+
* A struct used by PyArray_CreateSortedStridePerm
1764+
************************************************************/
1765+
1766+
typedef struct {
1767+
npy_intp perm, stride;
1768+
} npy_stride_sort_item;
1769+
17621770
/*
17631771
* This is the form of the struct that's returned pointed by the
17641772
* PyCObject attribute of an array __array_struct__. See

numpy/core/src/multiarray/convert.c

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,8 @@ PyArray_FillWithScalar(PyArrayObject *arr, PyObject *obj)
368368
return 0;
369369
}
370370

371-
/*
371+
/*NUMPY_API
372+
*
372373
* Fills an array with zeros.
373374
*
374375
* Returns 0 on success, -1 on failure.
@@ -398,9 +399,32 @@ PyArray_FillWithZero(PyArrayObject *a)
398399
return 0;
399400
}
400401

402+
/* If there's an object type, copy the value zero to everything */
403+
if (PyDataType_REFCHK(dtype)) {
404+
PyArrayObject *tmp;
405+
PyArray_Descr *bool_dtype = PyArray_DescrFromType(NPY_BOOL);
406+
int retcode;
407+
408+
/* Create a boolean array with 0 in it */
409+
if (dtype == NULL) {
410+
return -1;
411+
}
412+
tmp = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
413+
bool_dtype, 0, NULL, NULL,
414+
NULL, 0, NULL);
415+
if (tmp == NULL) {
416+
return -1;
417+
}
418+
PyArray_BYTES(tmp)[0] = 0;
419+
420+
retcode = PyArray_CopyInto(a, tmp);
421+
Py_DECREF(tmp);
422+
423+
return retcode;
424+
}
425+
401426
/* If it's possible to do a simple memset, do so */
402-
if (!PyDataType_REFCHK(dtype) && (PyArray_ISCONTIGUOUS(a) ||
403-
PyArray_ISFORTRAN(a))) {
427+
if (PyArray_IS_C_CONTIGUOUS(a) || PyArray_IS_F_CONTIGUOUS(a)) {
404428
memset(PyArray_DATA(a), 0, PyArray_NBYTES(a));
405429
return 0;
406430
}
@@ -463,6 +487,48 @@ PyArray_FillWithZero(PyArrayObject *a)
463487
return 0;
464488
}
465489

490+
/*NUMPY_API
491+
*
492+
* Fills an array with ones.
493+
*
494+
* Returns 0 on success, -1 on failure.
495+
*/
496+
NPY_NO_EXPORT int
497+
PyArray_FillWithOne(PyArrayObject *a)
498+
{
499+
PyArrayObject *tmp;
500+
PyArray_Descr *bool_dtype;
501+
int retcode;
502+
503+
if (!PyArray_ISWRITEABLE(a)) {
504+
PyErr_SetString(PyExc_RuntimeError, "cannot write to array");
505+
return -1;
506+
}
507+
508+
/* A zero-sized array needs no zeroing */
509+
if (PyArray_SIZE(a) == 0) {
510+
return 0;
511+
}
512+
513+
/* Create a boolean array with 0 in it */
514+
bool_dtype = PyArray_DescrFromType(NPY_BOOL);
515+
if (bool_dtype == NULL) {
516+
return -1;
517+
}
518+
tmp = (PyArrayObject *)PyArray_NewFromDescr(&PyArray_Type,
519+
bool_dtype, 0, NULL, NULL,
520+
NULL, 0, NULL);
521+
if (tmp == NULL) {
522+
return -1;
523+
}
524+
PyArray_BYTES(tmp)[0] = 1;
525+
526+
retcode = PyArray_CopyInto(a, tmp);
527+
Py_DECREF(tmp);
528+
529+
return retcode 7802 ;
530+
}
531+
466532
/*NUMPY_API
467533
* Copy an array.
468534
*/
@@ -534,6 +600,7 @@ PyArray_View(PyArrayObject *self, PyArray_Descr *type, PyTypeObject *pytype)
534600
PyErr_SetString(PyExc_RuntimeError,
535601
"NA masks with fields are not supported yet");
536602
Py_DECREF(ret);
603+
Py_DECREF(type);
537604
return NULL;
538605
}
539606

numpy/core/src/multiarray/convert.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@
44
NPY_NO_EXPORT int
55
PyArray_FillWithZero(PyArrayObject *a);
66

7+
NPY_NO_EXPORT int
8+
PyArray_FillWithOne(PyArrayObject *a);
9+
710
#endif

numpy/core/src/multiarray/ctors.c

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,23 +1391,23 @@ PyArray_NewLikeArray(PyArrayObject *prototype, NPY_ORDER order,
13911391
else {
13921392
npy_intp strides[NPY_MAXDIMS], stride;
13931393
npy_intp *shape = PyArray_DIMS(prototype);
1394-
_npy_stride_sort_item strideperm[NPY_MAXDIMS];
1395-
int i;
1394+
npy_stride_sort_item strideperm[NPY_MAXDIMS];
1395+
int idim;
13961396

13971397
PyArray_CreateSortedStridePerm(PyArray_NDIM(prototype),
13981398
PyArray_STRIDES(prototype),
13991399
strideperm);
14001400

14011401
/* Build the new strides */
14021402
stride = dtype->elsize;
1403-
for (i = ndim-1; i >= 0; --i) {
1404-
npy_intp i_perm = strideperm[i].perm;
1403+
for (idim = ndim-1; idim >= 0; --idim) {
1404+
npy_intp i_perm = strideperm[idim].perm;
14051405
strides[i_perm] = stride;
14061406
stride *= shape[i_perm];
14071407
}
14081408

14091409
/* Finally, allocate the array */
1410-
ret = PyArray_NewFromDescr( subok ? Py_TYPE(prototype) : &PyArray_Type,
1410+
ret = PyArray_NewFromDescr(subok ? Py_TYPE(prototype) : &PyArray_Type,
14111411
dtype,
14121412
ndim,
14131413
shape,

numpy/core/src/multiarray/dtype_transfer.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3894,7 +3894,7 @@ PyArray_PrepareOneRawArrayIter(int ndim, npy_intp *shape,
38943894
int *out_ndim, npy_intp *out_shape,
38953895
char **out_data, npy_intp *out_strides)
38963896
{
3897-
_npy_stride_sort_item strideperm[NPY_MAXDIMS];
3897+
npy_stride_sort_item strideperm[NPY_MAXDIMS];
38983898
int i, j;
38993899

39003900
/* Special case 0 and 1 dimensions */
@@ -3998,7 +3998,7 @@ PyArray_PrepareTwoRawArrayIter(int ndim, npy_intp *shape,
39983998
char **out_dataA, npy_intp *out_stridesA,
39993999
char **out_dataB, npy_intp *out_stridesB)
40004000
{
4001-
_npy_stride_sort_item strideperm[NPY_MAXDIMS];
4001+
npy_stride_sort_item strideperm[NPY_MAXDIMS];
40024002
int i, j;
40034003

40044004
/* Special case 0 and 1 dimensions */

numpy/core/src/multiarray/na_mask.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ PyArray_AllocateMaskNA(PyArrayObject *arr,
240240
fa->maskna_strides[0] = maskna_dtype->elsize;
241241
}
242242
else if (fa->nd > 1) {
243-
_npy_stride_sort_item strideperm[NPY_MAXDIMS];
243+
npy_stride_sort_item strideperm[NPY_MAXDIMS];
244244
npy_intp stride, maskna_strides[NPY_MAXDIMS], *shape;
245245
int i;
246246

numpy/core/src/multiarray/shape.c

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -779,8 +779,8 @@ PyArray_Transpose(PyArrayObject *ap, PyArray_Dims *permute)
779779
*/
780780
int _npy_stride_sort_item_comparator(const void *a, const void *b)
781781
{
782-
npy_intp astride = ((_npy_stride_sort_item *)a)->stride,
783-
bstride = ((_npy_stride_sort_item *)b)->stride;
782+
npy_intp astride = ((npy_stride_sort_item *)a)->stride,
783+
bstride = ((npy_stride_sort_item *)b)->stride;
784784

785785
/* Sort the absolute value of the strides */
786786
if (astride < 0) {
@@ -798,24 +798,25 @@ int _npy_stride_sort_item_comparator(const void *a, const void *b)
798798
* Make the qsort stable by next comparing the perm order.
799799
* (Note that two perm entries will never be equal)
800800
*/
801-
npy_intp aperm = ((_npy_stride_sort_item *)a)->perm,
802-
bperm = ((_npy_stride_sort_item *)b)->perm;
801+
npy_intp aperm = ((npy_stride_sort_item *)a)->perm,
802+
bperm = ((npy_stride_sort_item *)b)->perm;
803803
return (aperm < bperm) ? -1 : 1;
804804
}
805805
else {
806806
return 1;
807807
}
808808
}
809809

810-
/*
810+
/*NUMPY_API
811+
*
811812
* This function populates the first ndim elements
812813
* of strideperm with sorted descending by their absolute values.
813814
* For example, the stride array (4, -2, 12) becomes
814815
* [(2, 12), (0, 4), (1, -2)].
815816
*/
816817
NPY_NO_EXPORT void
817818
PyArray_CreateSortedStridePerm(int ndim, npy_intp * strides,
818-
_npy_stride_sort_item *strideperm)
819+
npy_stride_sort_item *strideperm)
819820
{
820821
int i;
821822

@@ -826,7 +827,7 @@ PyArray_CreateSortedStridePerm(int ndim, npy_intp * strides,
826827
}
827828

828829
/* Sort them */
829-
qsort(strideperm, ndim, sizeof(_npy_stride_sort_item),
830+
qsort(strideperm, ndim, sizeof(npy_stride_sort_item),
830831
&_npy_stride_sort_item_comparator);
831832
}
832833

@@ -862,7 +863,7 @@ PyArray_Ravel(PyArrayObject *a, NPY_ORDER order)
862863
}
863864
/* For KEEPORDER, check if we can make a flattened view */
864865
else if (order == NPY_KEEPORDER) {
865-
_npy_stride_sort_item strideperm[NPY_MAXDIMS];
866+
npy_stride_sort_item strideperm[NPY_MAXDIMS];
866867
npy_intp stride;
867868
int i, ndim = PyArray_NDIM(a);
868869

0 commit comments

Comments
 (0)
0