8000 Merge pull request #20139 from seberg/ufunc-at-new-api · thomasjpfan/numpy@1c613f7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1c613f7

Browse files
authored
Merge pull request numpy#20139 from seberg/ufunc-at-new-api
MAINT,BUG: Fix `ufunc.at` to use new ufunc API
2 parents 07447fd + ac214d2 commit 1c613f7

File tree

5 files changed

+116
-39
lines changed

5 files changed

+116
-39
lines changed

numpy/core/src/multiarray/dtypemeta.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ typedef struct {
7474
#define NPY_DTYPE(descr) ((PyArray_DTypeMeta *)Py_TYPE(descr))
7575
#define NPY_DT_SLOTS(dtype) ((NPY_DType_Slots *)(dtype)->dt_slots)
7676

77-
#define NPY_DT_is_legacy(dtype) ((dtype)->flags & NPY_DT_LEGACY)
78-
#define NPY_DT_is_abstract(dtype) ((dtype)->flags & NPY_DT_ABSTRACT)
79-
#define NPY_DT_is_parametric(dtype) ((dtype)->flags & NPY_DT_PARAMETRIC)
77+
#define NPY_DT_is_legacy(dtype) (((dtype)->flags & NPY_DT_LEGACY) != 0)
78+
#define NPY_DT_is_abstract(dtype) (((dtype)->flags & NPY_DT_ABSTRACT) != 0)
79+
#define NPY_DT_is_parametric(dtype) (((dtype)->flags & NPY_DT_PARAMETRIC) != 0)
8080

8181
/*
8282
* Macros for convenient classmethod calls, since these require

numpy/core/src/umath/dispatching.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,10 @@ resolve_implementation_info(PyUFuncObject *ufunc,
193193
/* Unspecified out always matches (see below for inputs) */
194194
continue;
195195
}
196+
if (resolver_dtype == (PyArray_DTypeMeta *)Py_None) {
197+
/* always matches */
198+
continue;
199+
}
196200
if (given_dtype == resolver_dtype) {
197201
continue;
198202 10000
}

numpy/core/src/umath/ufunc_object.c

Lines changed: 86 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5880,15 +5880,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
58805880
PyArrayObject *op2_array = NULL;
58815881
PyArrayMapIterObject *iter = NULL;
58825882
PyArrayIterObject *iter2 = NULL;
5883-
PyArray_Descr *dtypes[3] = {NULL, NULL, NULL};
58845883
PyArrayObject *operands[3] = {NULL, NULL, NULL};
58855884
PyArrayObject *array_operands[3] = {NULL, NULL, NULL};
58865885

5887-
int needs_api = 0;
5886+
PyArray_DTypeMeta *signature[3] = {NULL, NULL, NULL};
5887+
PyArray_DTypeMeta *operand_DTypes[3] = {NULL, NULL, NULL};
5888+
PyArray_Descr *operation_descrs[3] = {NULL, NULL, NULL};
58885889

5889-
PyUFuncGenericFunction innerloop;
5890-
void *innerloopdata;
5891-
npy_intp i;
58925890
int nop;
58935891

58945892
/* override vars */
@@ -5901,6 +5899,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59015899
int buffersize;
59025900
int errormask = 0;
59035901
char * err_msg = NULL;
5902+
5903+
PyArrayMethod_StridedLoop *strided_loop;
5904+
NpyAuxData *auxdata = NULL;
5905+
59045906
NPY_BEGIN_THREADS_DEF;
59055907

59065908
if (ufunc->nin > 2) {
@@ -5988,26 +5990,51 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
59885990

59895991
/*
59905992
* Create dtypes array for either one or two input operands.
5991-
* The output operand is set to the first input operand
5993+
* Compare to the logic in `convert_ufunc_arguments`.
5994+
* TODO: It may be good to review some of this behaviour, since the
5995+
* operand array is special (it is written to) similar to reductions.
5996+
* Using unsafe-casting as done here, is likely not desirable.
59925997
*/
59935998
operands[0] = op1_array;
5999+
operand_DTypes[0] = NPY_DTYPE(PyArray_DESCR(op1_array));
6000+
Py_INCREF(operand_DTypes[0]);
6001+
int force_legacy_promotion = 0;
6002+
int allow_legacy_promotion = NPY_DT_is_legacy(operand_DTypes[0]);
6003+
59946004
if (op2_array != NULL) {
59956005
operands[1] = op2_array;
5996-
operands[2] = op1_array;
6006+
operand_DTypes[1] = NPY_DTYPE(PyArray_DESCR(op2_array));
6007+
Py_INCREF(operand_DTypes[1]);
6008+
allow_legacy_promotion &= NPY_DT_is_legacy(operand_DTypes[1]);
6009+
operands[2] = operands[0];
6010+
operand_DTypes[2] = operand_DTypes[0];
6011+
Py_INCREF(operand_DTypes[2]);
6012+
59976013
nop = 3;
6014+
if (allow_legacy_promotion && ((PyArray_NDIM(op1_array) == 0)
6015+
!= (PyArray_NDIM(op2_array) == 0))) {
6016+
/* both are legacy and only one is 0-D: force legacy */
6017+
force_legacy_promotion = should_use_min_scalar(2, operands, 0, NULL);
6018+
}
59986019
}
59996020
else {
6000-
operands[1] = op1_array;
6021+
operands[1] = operands[0];
6022+
operand_DTypes[1] = operand_DTypes[0];
6023+
Py_INCREF(operand_DTypes[1]);
60016024
operands[2] = NULL;
60026025
nop = 2;
60036026
}
60046027

6005-
if (ufunc->type_resolver(ufunc, NPY_UNSAFE_CASTING,
6006-
operands, NULL, dtypes) < 0) {
6028+
PyArrayMethodObject *ufuncimpl = promote_and_get_ufuncimpl(ufunc,
6029+
operands, signature, operand_DTypes,
6030+
force_legacy_promotion, allow_legacy_promotion);
6031+
if (ufuncimpl == NULL) {
60076032
goto fail;
60086033
}
6009-
if (ufunc->legacy_inner_loop_selector(ufunc, dtypes,
6010-
&innerloop, &innerloopdata, &needs_api) < 0) {
6034+
6035+
/* Find the correct descriptors for the operation */
6036+
if (resolve_descriptors(nop, ufunc, ufuncimpl,
6037+
operands, operation_descrs, signature, NPY_UNSAFE_CASTING) < 0) {
60116038
goto fail;
60126039
}
60136040

@@ -6068,21 +6095,44 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60686095
NPY_ITER_GROWINNER|
60696096
NPY_ITER_DELAY_BUFALLOC,
60706097
NPY_KEEPORDER, NPY_UNSAFE_CASTING,
6071-
op_flags, dtypes,
6098+
op_flags, operation_descrs,
60726099
-1, NULL, NULL, buffersize);
60736100

60746101
if (iter_buffer == NULL) {
60756102
goto fail;
60766103
}
60776104

6078-
needs_api = needs_api | NpyIter_IterationNeedsAPI(iter_buffer);
6079-
60806105
iternext = NpyIter_GetIterNext(iter_buffer, NULL);
60816106
if (iternext == NULL) {
60826107
NpyIter_Deallocate(iter_buffer);
60836108
goto fail;
60846109
}
60856110

6111+
PyArrayMethod_Context context = {
6112+
.caller = (PyObject *)ufunc,
6113+
.method = ufuncimpl,
6114+
.descriptors = operation_descrs,
6115+
};
6116+
6117+
NPY_ARRAYMETHOD_FLAGS flags;
6118+
/* Use contiguous strides; if there is such a loop it may be faster */
6119+
npy_intp strides[3] = {
6120+
operation_descrs[0]->elsize, operation_descrs[1]->elsize, 0};
6121+
if (nop == 3) {
6122+
strides[2] = operation_descrs[2]->elsize;
6123+
}
6124+
6125+
if (ufuncimpl->get_strided_loop(&context, 1, 0, strides,
6126+
&strided_loop, &auxdata, &flags) < 0) {
6127+
goto fail;
6128+
}
6129+
int needs_api = (flags & NPY_METH_REQUIRES_PYAPI) != 0;
6130+
needs_api |= NpyIter_IterationNeedsAPI(iter_buffer);
6131+
if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
6132+
/* Start with the floating-point exception flags cleared */
6133+
npy_clear_floatstatus_barrier((char*)&iter);
6134+
}
6135+
60866136
if (!needs_api) {
60876137
NPY_BEGIN_THREADS;
60886138
}
@@ -6091,14 +6141,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
60916141
* Iterate over first and second operands and call ufunc
60926142
* for each pair of inputs
60936143
*/
6094-
i = iter->size;
6095-
while (i > 0)
6144+
int res = 0;
6145+
for (npy_intp i = iter->size; i > 0; i--)
60966146
{
60976147
char *dataptr[3];
60986148
char **buffer_dataptr;
60996149
/* one element at a time, no stride required but read by innerloop */
6100-
npy_intp count[3] = {1, 0xDEADBEEF, 0xDEADBEEF};
6101-
npy_intp stride[3] = {0xDEADBEEF, 0xDEADBEEF, 0xDEADBEEF};
6150+
npy_intp count = 1;
61026151

61036152
/*
61046153
* Set up data pointers for either one or two input operands.
@@ -6117,14 +6166,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61176166
/* Reset NpyIter data pointers which will trigger a buffer copy */
61186167
NpyIter_ResetBasePointers(iter_buffer, dataptr, &err_msg);
61196168
if (err_msg) {
6169+
res = -1;
61206170
break;
61216171
}
61226172

61236173
buffer_dataptr = NpyIter_GetDataPtrArray(iter_buffer);
61246174

6125-
innerloop(buffer_dataptr, count, stride, innerloopdata);
6126-
6127-
if (needs_api && PyErr_Occurred()) {
6175+
res = strided_loop(&context, buffer_dataptr, &count, strides, auxdata);
6176+
if (res != 0) {
61286177
break;
61296178
}
61306179

@@ -6138,32 +6187,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61386187
if (iter2 != NULL) {
61396188
PyArray_ITER_NEXT(iter2);
61406189
}
6141-
6142-
i--;
61436190
}
61446191

61456192
NPY_END_THREADS;
61466193

6147-
if (err_msg) {
6194+
if (res != 0 && err_msg) {
61486195
PyErr_SetString(PyExc_ValueError, err_msg);
61496196
}
6197+
if (res == 0 && !(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS)) {
6198+
/* NOTE: We could check float errors even when `res < 0` */
6199+
res = _check_ufunc_fperr(errormask, NULL, "at");
6200+
}
61506201

6202+
NPY_AUXDATA_FREE(auxdata);
61516203
NpyIter_Deallocate(iter_buffer);
61526204

61536205
Py_XDECREF(op2_array);
61546206
Py_XDECREF(iter);
61556207
Py_XDECREF(iter2);
6156-
for (i = 0; i < 3; i++) {
6157-
Py_XDECREF(dtypes[i]);
6208+
for (int i = 0; i < 3; i++) {
6209+
Py_XDECREF(operation_descrs[i]);
61586210
Py_XDECREF(array_operands[i]);
61596211
}
61606212

61616213
/*
6162-
* An error should only be possible if needs_api is true, but this is not
6163-
* strictly correct for old-style ufuncs (e.g. `power` released the GIL
6164-
* but manually set an Exception).
6214+
* An error should only be possible if needs_api is true or `res != 0`,
6215+
* but this is not strictly correct for old-style ufuncs
6216+
* (e.g. `power` released the GIL but manually set an Exception).
61656217
*/
6166-
if (PyErr_Occurred()) {
6218+
if (res != 0 || PyErr_Occurred()) {
61676219
return NULL;
61686220
}
61696221
else {
@@ -6178,10 +6230,11 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
61786230
Py_XDECREF(op2_array);
61796231
Py_XDECREF(iter);
61806232
Py_XDECREF(iter2);
6181-
for (i = 0; i < 3; i++) {
6182-
Py_XDECREF(dtypes[i]);
6233+
for (int i = 0; i < 3; i++) {
6234+
Py_XDECREF(operation_descrs[i]);
61836235
Py_XDECREF(array_operands[i]);
61846236
}
6237+
NPY_AUXDATA_FREE(auxdata);
61856238

61866239
return NULL;
61876240
}

numpy/core/tests/test_custom_dtypes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,18 +117,36 @@ def test_possible_and_impossible_reduce(self):
117117
match="the resolved dtypes are not compatible"):
118118
np.multiply.reduce(a)
119119

120+
def test_basic_ufunc_at(self):
121+
float_a = np.array([1., 2., 3.])
122+
b = self._get_array(2.)
123+
124+
float_b = b.view(np.float64).copy()
125+
np.multiply.at(float_b, [1, 1, 1], float_a)
126+
np.multiply.at(b, [1, 1, 1], float_a)
127+
128+
assert_array_equal(b.view(np.float64), float_b)
129+
120130
def test_basic_multiply_promotion(self):
121131
float_a = np.array([1., 2., 3.])
122132
b = self._get_array(2.)
123133

124134
res1 = float_a * b
125135
res2 = b * float_a
136+
126137
# one factor is one, so we get the factor of b:
127138
assert res1.dtype == res2.dtype == b.dtype
128139
expected_view = float_a * b.view(np.float64)
129140
assert_array_equal(res1.view(np.float64), expected_view)
130141
assert_array_equal(res2.view(np.float64), expected_view)
131142

143+
# Check that promotion works when `out` is used:
144+
np.multiply(b, float_a, out=res2)
145+
with pytest.raises(TypeError):
146+
# The promoter accepts this (maybe it should not), but the SFloat
147+
# result cannot be cast to integer:
148+
np.multiply(b, float_a, out=np.arange(3))
149+
132150
def test_basic_addition(self):
133151
a = self._get_array(2.)
134152
b = self._get_array(4.)

numpy/core/tests/test_ufunc.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2397,14 +2397,16 @@ def test_reduce_casterrors(offset):
23972397

23982398
@pytest.mark.parametrize("method",
23992399
[np.add.accumulate, np.add.reduce,
2400-
pytest.param(lambda x: np.add.reduceat(x, [0]), id="reduceat")])
2401-
def test_reducelike_floaterrors(method):
2402-
# adding inf and -inf creates an invalid float and should give a warning
2400+
pytest.param(lambda x: np.add.reduceat(x, [0]), id="reduceat"),
2401+
pytest.param(lambda x: np.log.at(x, [2]), id="at")])
2402+
def test_ufunc_methods_floaterrors(method):
2403+
# adding inf and -inf (or log(-inf) creates an invalid float and warns
24032404
arr = np.array([np.inf, 0, -np.inf])
24042405
with np.errstate(all="warn"):
24052406
with pytest.warns(RuntimeWarning, match="invalid value"):
24062407
method(arr)
24072408

2409+
arr = np.array([np.inf, 0, -np.inf])
24082410
with np.errstate(all="raise"):
24092411
with pytest.raises(FloatingPointError):
24102412
method(arr)

0 commit comments

Comments
 (0)
0