8000 MAINT: Factor out array prepare/wrap method look-up on inputs. · eric-wieser/numpy@43cdf9f · GitHub
[go: up one dir, main page]

Skip to content

Commit 43cdf9f

Browse files
mhvkMarten H. van Kerkwijk
authored andcommitted
MAINT: Factor out array prepare/wrap method look-up on inputs.
1 parent 7255987 commit 43cdf9f

File tree

1 file changed

+73
-97
lines changed

1 file changed

+73
-97
lines changed

numpy/core/src/umath/ufunc_object.c

Lines changed: 73 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,65 @@ PyUFunc_clearfperr()
125125
npy_clear_floatstatus();
126126
}
127127

128+
/*
129+
* This function analyzes the input arguments and determines an appropriate
130+
* method (__array_prepare__ or __array_wrap__) function to call, taking it
131+
* from the input with the highest priority. Return NULL if no argument
132+
* defines the method.
133+
*/
134+
static PyObject*
135+
_find_array_method(PyObject *args, int nin, PyObject *method_name)
136+
{
137+
int i, n_methods;
138+
PyObject *obj;
139+
PyObject *with_method[NPY_MAXARGS], *methods[NPY_MAXARGS];
140+
PyObject *method = NULL;
141+
142+
n_methods = 0;
143+
for (i = 0; i < nin; i++) {
144+
obj = PyTuple_GET_ITEM(args, i);
145+
if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) {
146+
continue;
147+
}
148+
method = PyObject_GetAttr(obj, method_name);
149+
if (method) {
150+
if (PyCallable_Check(method)) {
151+
with_method[n_methods] = obj;
152+
methods[n_methods] = method;
153+
++n_methods;
154+
}
155+
else {
156+
Py_DECREF(method);
157+
method = NULL;
158+
}
159+
}
160+
else {
161+
PyErr_Clear();
162+
}
163+
}
164+
if (n_methods > 0) {
165+
/* If we have some methods defined, find the one of highest priority */
166+
method = methods[0];
167+
if (n_methods > 1) {
168+
double maxpriority = PyArray_GetPriority(with_method[0],
169+
NPY_PRIORITY);
170+
for (i = 1; i < n_methods; ++i) {
171+
double priority = PyArray_GetPriority(with_method[i],
172+
NPY_PRIORITY);
173+
if (priority > maxpriority) {
174+
maxpriority = priority;
175+
Py_DECREF(method);
176+
method = methods[i];
177+
}
178+
else {
179+
Py_DECREF(methods[i]);
180+
}
181+
}
182+
}
183+
}
184+
return method;
185+
}
186+
128187
/*
129188
* This function analyzes the input arguments
130189
* and determines an appropriate __array_prepare__ function to call
@@ -149,9 +208,7 @@ _find_array_prepare(PyObject *args, PyObject *kwds,
149208
{
150209
Py_ssize_t nargs;
151210
int i;
152-
int np = 0;
153-
PyObject *with_prep[NPY_MAXARGS], *preps[NPY_MAXARGS];
154-
PyObject *obj, *prep = NULL;
211+
PyObject *obj, *prep;
155212

156213
/*
157214
* If a 'subok' parameter is passed and isn't True, don't wrap
@@ -167,53 +224,12 @@ _find_array_prepare(PyObject *args, PyObject *kwds,
167224
}
168225
}
169226

170-
nargs = PyTuple_GET_SIZE(args);
171-
for (i = 0; i < nin; i++) {
172-
obj = PyTuple_GET_ITEM(args, i);
173-
if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) {
174-
continue;
175-
}
176-
prep = PyObject_GetAttr(obj, npy_um_str_array_prepare);
177-
if (prep) {
178-
if (PyCallable_Check(prep)) {
179-
with_prep[np] = obj;
180-
preps[np] = prep;
181-
++np;
182-
}
183-
else {
184-
Py_DECREF(prep);
185-
prep = NULL;
186-
}
187-
}
188-
else {
189-
PyErr_Clear();
190-
}
191-
}
192-
if (np > 0) {
193-
/* If we have some preps defined, find the one of highest priority */
194-
prep = preps[0];
195-
if (np > 1) {
196-
double maxpriority = PyArray_GetPriority(with_prep[0],
197-
NPY_PRIORITY);
198-
for (i = 1; i < np; ++i) {
199-
double priority = PyArray_GetPriority(with_prep[i],
200-
NPY_PRIORITY);
201-
if (priority > maxpriority) {
202-
maxpriority = priority;
203-
Py_DECREF(prep);
204-
prep = preps[i];
205-
}
206-
else {
207-
Py_DECREF(preps[i]);
208-
}
209-
}
210-
}
211-
}
212-
213227
/*
214-
* Here prep is the prepping function determined from the
215-
* input arrays (could be NULL).
216-
*
228+
* Determine the prepping function given by the input arrays
229+
* (could be NULL).
230+
*/
231+
prep = _find_array_method(args, nin, npy_um_str_array_prepare);
232+
/*
217233
* For all the output arrays decide what to do.
218234
*
219235
* 1) Use the prep function determined from the input arrays
@@ -225,6 +241,7 @@ _find_array_prepare(PyObject *args, PyObject *kwds,
225241
* exact ndarray so that no PyArray_Return is
226242
* done in that case.
227243
*/
244+
nargs = PyTuple_GET_SIZE(args);
228245
for (i = 0; i < nout; i++) {
229246
int j = nin + i;
230247
int incref = 1;
@@ -3946,9 +3963,8 @@ _find_array_wrap(PyObject *args, PyObject *kwds,
39463963
{
39473964
Py_ssize_t nargs;
39483965
int i, idx_offset, start_idx;
3949-
int np = 0;
3950-
PyObject *with_wrap[NPY_MAXARGS], *wraps[NPY_MAXARGS];
3951-
PyObject *obj, *wrap = NULL;
3966+
PyObject *obj;
3967+
PyObject *wrap = NULL;
39523968

39533969
/*
39543970
* If a 'subok' parameter is passed and isn't True, don't wrap but put None
@@ -3962,53 +3978,13 @@ _find_array_wrap(PyObject *args, PyObject *kwds,
39623978
}
39633979
}
39643980

3965-
3966-
for (i = 0; i < nin; i++) {
3967-
obj = PyTuple_GET_ITEM(args, i);
3968-
if (PyArray_CheckExact(obj) || PyArray_IsAnyScalar(obj)) {
3969-
continue;
3970-
}
3971-
wrap = PyObject_GetAttr(obj, npy_um_str_array_wrap);
3972-
if (wrap) {
3973-
if (PyCallable_Check(wrap)) {
3974-
with_wrap[np] = obj;
3975-
wraps[np] = wrap;
3976-
++np;
3977-
}
3978-
else {
3979-
Py_DECREF(wrap);
3980-
wrap = NULL;
3981-
}
3982-
}
3983-
else {
3984-
PyErr_Clear();
3985-
}
3986-
}
3987-
if (np > 0) {
3988-
/* If we have some wraps defined, find the one of highest priority */
3989-
wrap = wraps[0];
3990-
if (np > 1) {
3991-
double maxpriority = PyArray_GetPriority(with_wrap[0],
3992-
NPY_PRIORITY);
3993-
for (i = 1; i < np; ++i) {
3994-
double priority = PyArray_GetPriority(with_wrap[i],
3995-
NPY_PRIORITY);
3996-
if (priority > maxpriority) {
3997-
maxpriority = priority;
3998-
Py_DECREF(wrap);
3999-
wrap = wraps[i];
4000-
}
4001-
else {
4002-
Py_DECREF(wraps[i]);
4003-
}
4004-
}
4005-
}
4006-
}
3981+
/*
3982+
* Determine the wrapping function given by the input arrays
3983+
* (could be NULL).
3984+
*/
3985+
wrap = _find_array_method(args, nin, npy_um_str_array_wrap);
40073986

40083987
/*
4009-
* Here wrap is the wrapping function determined from the
4010-
* input arrays (could be NULL).
4011-
*
40123988
* For all the output arrays decide what to do.
40133989
*
40143990
* 1) Use the wrap function determined from the input arrays

0 commit comments

Comments
 (0)
0