@@ -5880,15 +5880,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
5880
5880
PyArrayObject * op2_array = NULL ;
5881
5881
PyArrayMapIterObject * iter = NULL ;
5882
5882
PyArrayIterObject * iter2 = NULL ;
5883
- PyArray_Descr * dtypes [3 ] = {NULL , NULL , NULL };
5884
5883
PyArrayObject * operands [3 ] = {NULL , NULL , NULL };
5885
5884
PyArrayObject * array_operands [3 ] = {NULL , NULL , NULL };
5886
5885
5887
- int needs_api = 0 ;
5886
+ PyArray_DTypeMeta * signature [3 ] = {NULL , NULL , NULL };
5887
+ PyArray_DTypeMeta * operand_DTypes [3 ] = {NULL , NULL , NULL };
5888
+ PyArray_Descr * operation_descrs [3 ] = {NULL , NULL , NULL };
5888
5889
5889
- PyUFuncGenericFunction innerloop ;
5890
- void * innerloopdata ;
5891
- npy_intp i ;
5892
5890
int nop ;
5893
5891
5894
5892
/* override vars */
@@ -5901,6 +5899,10 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
5901
5899
int buffersize ;
5902
5900
int errormask = 0 ;
5903
5901
char * err_msg = NULL ;
5902
+
5903
+ PyArrayMethod_StridedLoop * strided_loop ;
5904
+ NpyAuxData * auxdata = NULL ;
5905
+
5904
5906
NPY_BEGIN_THREADS_DEF ;
5905
5907
5906
5908
if (ufunc -> nin > 2 ) {
@@ -5988,26 +5990,51 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
5988
5990
5989
5991
/*
5990
5992
* Create dtypes array for either one or two input operands.
5991
- * The output operand is set to the first input operand
5993
+ * Compare to the logic in `convert_ufunc_arguments`.
5994
+ * TODO: It may be good to review some of this behaviour, since the
5995
+ * operand array is special (it is written to) similar to reductions.
5996
+ * Using unsafe-casting as done here, is likely not desirable.
5992
5997
*/
5993
5998
operands [0 ] = op1_array ;
5999
+ operand_DTypes [0 ] = NPY_DTYPE (PyArray_DESCR (op1_array ));
6000
+ Py_INCREF (operand_DTypes [0 ]);
6001
+ int force_legacy_promotion = 0 ;
6002
+ int allow_legacy_promotion = NPY_DT_is_legacy (operand_DTypes [0 ]);
6003
+
5994
6004
if (op2_array != NULL ) {
5995
6005
operands [1 ] = op2_array ;
5996
- operands [2 ] = op1_array ;
6006
+ operand_DTypes [1 ] = NPY_DTYPE (PyArray_DESCR (op2_array ));
6007
+ Py_INCREF (operand_DTypes [1 ]);
6008
+ allow_legacy_promotion &= NPY_DT_is_legacy (operand_DTypes [1 ]);
6009
+ operands [2 ] = operands [0 ];
6010
+ operand_DTypes [2 ] = operand_DTypes [0 ];
6011
+ Py_INCREF (operand_DTypes [2 ]);
6012
+
5997
6013
nop = 3 ;
6014
+ if (allow_legacy_promotion && ((PyArray_NDIM (op1_array ) == 0 )
6015
+ != (PyArray_NDIM (op2_array ) == 0 ))) {
6016
+ /* both are legacy and only one is 0-D: force legacy */
6017
+ force_legacy_promotion = should_use_min_scalar (2 , operands , 0 , NULL );
6018
+ }
5998
6019
}
5999
6020
else {
6000
- operands [1 ] = op1_array ;
6021
+ operands [1 ] = operands [0 ];
6022
+ operand_DTypes [1 ] = operand_DTypes [0 ];
6023
+ Py_INCREF (operand_DTypes [1 ]);
6001
6024
operands [2 ] = NULL ;
6002
6025
nop = 2 ;
6003
6026
}
6004
6027
6005
- if (ufunc -> type_resolver (ufunc , NPY_UNSAFE_CASTING ,
6006
- operands , NULL , dtypes ) < 0 ) {
6028
+ PyArrayMethodObject * ufuncimpl = promote_and_get_ufuncimpl (ufunc ,
6029
+ operands , signature , operand_DTypes ,
6030
+ force_legacy_promotion , allow_legacy_promotion );
6031
+ if (ufuncimpl == NULL ) {
6007
6032
goto fail ;
6008
6033
}
6009
- if (ufunc -> legacy_inner_loop_selector (ufunc , dtypes ,
6010
- & innerloop , & innerloopdata , & needs_api ) < 0 ) {
6034
+
6035
+ /* Find the correct descriptors for the operation */
6036
+ if (resolve_descriptors (nop , ufunc , ufuncimpl ,
6037
+ operands , operation_descrs , signature , NPY_UNSAFE_CASTING ) < 0 ) {
6011
6038
goto fail ;
6012
6039
}
6013
6040
@@ -6068,21 +6095,44 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
6068
6095
NPY_ITER_GROWINNER |
6069
6096
NPY_ITER_DELAY_BUFALLOC ,
6070
6097
NPY_KEEPORDER , NPY_UNSAFE_CASTING ,
6071
- op_flags , dtypes ,
6098
+ op_flags , operation_descrs ,
6072
6099
-1 , NULL , NULL , buffersize );
6073
6100
6074
6101
if (iter_buffer == NULL ) {
6075
6102
goto fail ;
6076
6103
}
6077
6104
6078
- needs_api = needs_api | NpyIter_IterationNeedsAPI (iter_buffer );
6079
-
6080
6105
iternext = NpyIter_GetIterNext (iter_buffer , NULL );
6081
6106
if (iternext == NULL ) {
6082
6107
NpyIter_Deallocate (iter_buffer );
6083
6108
goto fail ;
6084
6109
}
6085
6110
6111
+ PyArrayMethod_Context context = {
6112
+ .caller = (PyObject * )ufunc ,
6113
+ .method = ufuncimpl ,
6114
+ .descriptors = operation_descrs ,
6115
+ };
6116
+
6117
+ NPY_ARRAYMETHOD_FLAGS flags ;
6118
+ /* Use contiguous strides; if there is such a loop it may be faster */
6119
+ npy_intp strides [3 ] = {
6120
+ operation_descrs [0 ]-> elsize , operation_descrs [1 ]-> elsize , 0 };
6121
+ if (nop == 3 ) {
6122
+ strides [2 ] = operation_descrs [2 ]-> elsize ;
6123
+ }
6124
+
6125
+ if (ufuncimpl -> get_strided_loop (& context , 1 , 0 , strides ,
6126
+ & strided_loop , & auxdata , & flags ) < 0 ) {
6127
+ goto fail ;
6128
+ }
6129
+ int needs_api = (flags & NPY_METH_REQUIRES_PYAPI ) != 0 ;
6130
+ needs_api |= NpyIter_IterationNeedsAPI (iter_buffer );
6131
+ if (!(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6132
+ /* Start with the floating-point exception flags cleared */
6133
+ npy_clear_floatstatus_barrier ((char * )& iter );
6134
+ }
6135
+
6086
6136
if (!needs_api ) {
6087
6137
NPY_BEGIN_THREADS ;
6088
6138
}
@@ -6091,14 +6141,13 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
6091
6141
* Iterate over first and second operands and call ufunc
6092
6142
* for each pair of inputs
6093
6143
*/
6094
- i = iter -> size ;
6095
- while ( i > 0 )
6144
+ int res = 0 ;
6145
+ for ( npy_intp i = iter -> size ; i > 0 ; i -- )
6096
6146
{
6097
6147
char * dataptr [3 ];
6098
6148
char * * buffer_dataptr ;
6099
6149
/* one element at a time, no stride required but read by innerloop */
6100
- npy_intp count [3 ] = {1 , 0xDEADBEEF , 0xDEADBEEF };
6101
- npy_intp stride [3 ] = {0xDEADBEEF , 0xDEADBEEF , 0xDEADBEEF };
6150
+ npy_intp count = 1 ;
6102
6151
6103
6152
/*
6104
6153
* Set up data pointers for either one or two input operands.
@@ -6117,14 +6166,14 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
6117
6166
/* Reset NpyIter data pointers which will trigger a buffer copy */
6118
6167
NpyIter_ResetBasePointers (iter_buffer , dataptr , & err_msg );
6119
6168
if (err_msg ) {
6169
+ res = -1 ;
6120
6170
break ;
6121
6171
}
6122
6172
6123
6173
buffer_dataptr = NpyIter_GetDataPtrArray (iter_buffer );
6124
6174
6125
- innerloop (buffer_dataptr , count , stride , innerloopdata );
6126
-
6127
- if (needs_api && PyErr_Occurred ()) {
6175
+ res = strided_loop (& context , buffer_dataptr , & count , strides , auxdata );
6176
+ if (res != 0 ) {
6128
6177
break ;
6129
6178
}
6130
6179
@@ -6138,32 +6187,35 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
6138
6187
if (iter2 != NULL ) {
6139
6188
PyArray_ITER_NEXT (iter2 );
6140
6189
}
6141
-
6142
- i -- ;
6143
6190
}
6144
6191
6145
6192
NPY_END_THREADS ;
6146
6193
6147
- if (err_msg ) {
6194
+ if (res != 0 && err_msg ) {
6148
6195
PyErr_SetString (PyExc_ValueError , err_msg );
6149
6196
}
6197
+ if (res == 0 && !(flags & NPY_METH_NO_FLOATINGPOINT_ERRORS )) {
6198
+ /* NOTE: We could check float errors even when `res < 0` */
6199
+ res = _check_ufunc_fperr (errormask , NULL , "at" );
6200
+ }
6150
6201
6202
+ NPY_AUXDATA_FREE (auxdata );
6151
6203
NpyIter_Deallocate (iter_buffer );
6152
6204
6153
6205
Py_XDECREF (op2_array );
6154
6206
Py_XDECREF (iter );
6155
6207
Py_XDECREF (iter2 );
6156
- for (i = 0 ; i < 3 ; i ++ ) {
6157
- Py_XDECREF (dtypes [i ]);
6208
+ for (int i = 0 ; i < 3 ; i ++ ) {
6209
+ Py_XDECREF (operation_descrs [i ]);
6158
6210
Py_XDECREF (array_operands [i ]);
6159
6211
}
6160
6212
6161
6213
/*
6162
- * An error should only be possible if needs_api is true, but this is not
6163
- * strictly correct for old-style ufuncs (e.g. `power` released the GIL
6164
- * but manually set an Exception).
6214
+ * An error should only be possible if needs_api is true or `res != 0`,
6215
+ * but this is not strictly correct for old-style ufuncs
6216
+ * (e.g. `power` released the GIL but manually set an Exception).
6165
6217
*/
6166
- if (PyErr_Occurred ()) {
6218
+ if (res != 0 || PyErr_Occurred ()) {
6167
6219
return NULL ;
6168
6220
}
6169
6221
else {
@@ -6178,10 +6230,11 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
6178
6230
Py_XDECREF (op2_array );
6179
6231
Py_XDECREF (iter );
6180
6232
Py_XDECREF (iter2 );
6181
- for (i = 0 ; i < 3 ; i ++ ) {
6182
- Py_XDECREF (dtypes [i ]);
6233
+ for (int i = 0 ; i < 3 ; i ++ ) {
6234
+ Py_XDECREF (operation_descrs [i ]);
6183
6235
Py_XDECREF (array_operands [i ]);
6184
6236
}
6237
+ NPY_AUXDATA_FREE (auxdata );
6185
6238
6186
6239
return NULL ;
6187
6240
}
0 commit comments