8000 API,MAINT: Make ``NpyIter_GetTransferFlags`` public and avoid old uses by seberg · Pull Request #27998 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

API,MAINT: Make NpyIter_GetTransferFlags public and avoid old uses #27998

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
MAINT: Remove/move iterator "needs api" and avoid it in reductions
  • Loading branch information
seberg committed Dec 17, 2024
commit 201c0b5a9686482674897638740bd1895acc51a5
20 changes: 17 additions & 3 deletions numpy/_core/src/multiarray/nditer_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -866,7 +866,23 @@ NpyIter_RequiresBuffering(NpyIter *iter)
NPY_NO_EXPORT npy_bool
NpyIter_IterationNeedsAPI(NpyIter *iter)
{
return (NIT_ITFLAGS(iter)&NPY_ITFLAG_NEEDSAPI) != 0;
int nop = NIT_NOP(iter);
/* If any of the buffer filling need the API, flag it as well. */
if (NpyIter_GetTransferFlags(iter) & NPY_METH_REQUIRES_PYAPI) {
return NPY_TRUE;
}

for (int iop = 0; iop < nop; ++iop) {
PyArray_Descr *rdt = NIT_DTYPES(iter)[iop];
if ((rdt->flags & (NPY_ITEM_REFCOUNT |
NPY_ITEM_IS_POINTER |
NPY_NEEDS_PYAPI)) != 0) {
/* Iteration needs API access */
return NPY_TRUE;
}
}

return NPY_FALSE;
}


Expand Down Expand Up @@ -1420,8 +1436,6 @@ NpyIter_DebugPrint(NpyIter *iter)
printf("ONEITERATION ");
if (itflags&NPY_ITFLAG_DELAYBUF)
printf("DELAYBUF ");
if (itflags&NPY_ITFLAG_NEEDSAPI)
printf("NEEDSAPI ");
if (itflags&NPY_ITFLAG_REDUCE)
printf("REDUCE ");
if (itflags&NPY_ITFLAG_REUSE_REDUCE_LOOPS)
Expand Down
26 changes: 0 additions & 26 deletions numpy/_core/src/multiarray/nditer_constr.c
Original file line number Diff line number Diff line change
Expand Up @@ -432,27 +432,6 @@ NpyIter_AdvancedNew(int nop, PyArrayObject **op_in, npy_uint32 flags,
}
}

/*
* If REFS_OK was specified, check whether there are any
* reference arrays and flag it if so.
*
* NOTE: This really should be unnecessary, but chances are someone relies
* on it. The iterator itself does not require the API here
* as it only does so for casting/buffering. But in almost all
* use-cases the API will be required for whatever operation is done.
*/
if (flags & NPY_ITER_REFS_OK) {
for (iop = 0; iop < nop; ++iop) {
PyArray_Descr *rdt = op_dtype[iop];
if ((rdt->flags & (NPY_ITEM_REFCOUNT |
NPY_ITEM_IS_POINTER |
NPY_NEEDS_PYAPI)) != 0) {
/* Iteration needs API access */
NIT_ITFLAGS(iter) |= NPY_ITFLAG_NEEDSAPI;
}
}
}

/* If buffering is set prepare it */
if (itflags & NPY_ITFLAG_BUFFER) {
npyiter_find_buffering_setup(iter, buffersize);
Expand Down Expand Up @@ -3566,11 +3545,6 @@ npyiter_allocate_transfer_functions(NpyIter *iter)
NIT_ITFLAGS(iter) |= cflags << NPY_ITFLAG_TRANSFERFLAGS_SHIFT;
assert(NIT_ITFLAGS(iter) >> NPY_ITFLAG_TRANSFERFLAGS_SHIFT == cflags);

/* If any of the dtype transfer functions needed the API, flag it. */
if (cflags & NPY_METH_REQUIRES_PYAPI) {
NIT_ITFLAGS(iter) |= NPY_ITFLAG_NEEDSAPI;
}

return 1;

fail:
Expand Down
6 changes: 2 additions & 4 deletions numpy/_core/src/multiarray/nditer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,10 @@
#define NPY_ITFLAG_ONEITERATION (1 << 9)
/* Delay buffer allocation until first Reset* call */
#define NPY_ITFLAG_DELAYBUF (1 << 10)
/* Iteration needs API access during iternext */
#define NPY_ITFLAG_NEEDSAPI (1 << 11)
/* Iteration includes one or more operands being reduced */
#define NPY_ITFLAG_REDUCE (1 << 12)
#define NPY_ITFLAG_REDUCE (1 << 11)
/* Reduce iteration doesn't need to recalculate reduce loops next time */
#define NPY_ITFLAG_REUSE_REDUCE_LOOPS (1 << 13)
#define NPY_ITFLAG_REUSE_REDUCE_LOOPS (1 << 12)
/*
* Offset of (combined) ArrayMethod flags for all transfer functions.
* For now, we use the top 8 bits.
Expand Down
41 changes: 20 additions & 21 deletions numpy/_core/src/umath/reduction.c
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
#include "lowlevel_strided_loops.h"
#include "reduction.h"
#include "extobj.h" /* for _check_ufunc_fperr */
/* TODO: Only for `NpyIter_GetTransferFlags` until it is public */
#define NPY_ITERATOR_IMPLEMENTATION_CODE
#include "nditer_impl.h"


/*
Expand Down Expand Up @@ -339,10 +342,25 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
}

PyArrayMethod_StridedLoop *strided_loop;
NPY_ARRAYMETHOD_FLAGS flags = 0;
NPY_ARRAYMETHOD_FLAGS flags;

npy_intp fixed_strides[3];
NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
if (wheremask != NULL) {
if (PyArrayMethod_GetMaskedStridedLoop(context,
1, fixed_strides, &strided_loop, &auxdata, &flags) < 0) {
goto fail;
}
}
else {
if (context->method->get_strided_loop(context,
1, 0, fixed_strides, &strided_loop, &auxdata, &flags) < 0) {
goto fail;
}
}
flags = PyArrayMethod_COMBINED_FLAGS(flags, NpyIter_GetTransferFlags(iter));

int needs_api = (flags & NPY_METH_REQUIRES_PYAPI) != 0;
needs_api |= NpyIter_IterationNeedsAPI(iter);
if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
/* Start with the floating-point exception flags cleared */
npy_clear_floatstatus_barrier((char*)&iter);
Expand Down Expand Up @@ -389,25 +407,6 @@ PyUFunc_ReduceWrapper(PyArrayMethod_Context *context,
goto fail;
}

/*
* Note that we need to ensure that the iterator is reset before getting
* the fixed strides. (The buffer information is uninitialized before.)
*/
npy_intp fixed_strides[3];
NpyIter_GetInnerFixedStrideArray(iter, fixed_strides);
if (wheremask != NULL) {
if (PyArrayMethod_GetMaskedStridedLoop(context,
1, fixed_strides, &strided_loop, &auxdata, &flags) < 0) {
goto fail;
}
}
else {
if (context->method->get_strided_loop(context,
1, 0, fixed_strides, &strided_loop, &auxdata, &flags) < 0) {
goto fail;
}
}

if (!empty_iteration) {
NpyIter_IterNextFunc *iternext;
char **dataptr;
Expand Down
0