diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h index c47c46a66c9b..6519f69ddee5 100644 --- a/numpy/core/src/private/ufunc_override.h +++ b/numpy/core/src/private/ufunc_override.h @@ -6,6 +6,7 @@ #include #include "numpy/ufuncobject.h" +/* Normalize different ufunc methods */ static void normalize___call___args(PyUFuncObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds, @@ -152,6 +153,64 @@ normalize_at_args(PyUFuncObject *ufunc, PyObject *args, return; } +/* + * Check args for a object with a `__numpy_ufunc__` attribute. + */ +static int +check_tuple_for_numpy_ufunc(PyObject *args, int *noa, + int *i, PyObject *with_override[NPY_MAXARGS], + int with_override_pos[NPY_MAXARGS]) { + int og_i = *i; /* original i */ + int og_noa = *noa; /* original number of overriding args */ + int nargs = PyTuple_GET_SIZE(args); + PyObject *obj; + + for (; *i < (og_i + nargs); ++*i) { + + obj = PyTuple_GET_ITEM(args, *i); + + /* short circuit optimization for common cases. */ + if (PyArray_CheckExact(obj) || PyArray_IsScalar(obj, Generic) || + _is_basic_python_type(obj)) { + continue; + } + + if (PyObject_HasAttrString(obj, "__numpy_ufunc__")) { + with_override[*noa] = obj; + with_override_pos[*noa] = *i; + ++*noa; + } + } + if ((noa - og_noa) > 0) { + return 1; + } + return 0; +} + +static int +check_kwds_for_numpy_ufunc(PyObject *kwds, int *noa, int *i, + PyObject **with_override, + int with_override_pos[NPY_MAXARGS]) { + PyObject *obj; + if ((kwds)&& (PyDict_CheckExact(kwds))) { + obj = PyDict_GetItemString(kwds, "out"); + if (obj != NULL) { + if (PyObject_HasAttrString(obj, "__numpy_ufunc__")) { + with_override[*noa] = obj; + with_override_pos[*noa] = *i; + ++*noa; + return 1; + } + if PyTuple_CheckExact(obj) { + return check_tuple_for_numpy_ufunc(obj, noa, i, + with_override, + with_override_pos); + } + } + } + return 0; +} + /* * Check a set of args for the `__numpy_ufunc__` method. If more than one of * the input arguments implements `__numpy_ufunc__`, they are tried in the @@ -169,11 +228,10 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, PyObject **result, int nin) { - int i; - int override_pos; /* Position of override in args.*/ + int i = 0; + int override_pos; /* Position of winning override in args.*/ int j; - int nargs = PyTuple_GET_SIZE(args); int noa = 0; /* Number of overriding args.*/ PyObject *obj; @@ -205,22 +263,8 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, goto fail; } - for (i = 0; i < nargs; ++i) { - obj = PyTuple_GET_ITEM(args, i); - /* - * TODO: could use PyArray_GetAttrString_SuppressException if it - * weren't private to multiarray.so - */ - if (PyArray_CheckExact(obj) || PyArray_IsScalar(obj, Generic) || - _is_basic_python_type(obj)) { - continue; - } - if (PyObject_HasAttrString(obj, "__numpy_ufunc__")) { - with_override[noa] = obj; - with_override_pos[noa] = i; - ++noa; - } - } + check_tuple_for_numpy_ufunc(args, &noa, &i, with_override, with_override_pos); + check_kwds_for_numpy_ufunc(kwds, &noa, &i, with_override, with_override_pos); /* No overrides, bail out.*/ if (noa == 0) { diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index b0d6770527f5..1a19f87e7600 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -2240,15 +2240,15 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw): return "ufunc" else: inputs = list(inputs) - inputs[i] = np.asarray(self) + if i < len(inputs): + inputs[i] = np.asarray(self) func = getattr(ufunc, method) + if ('out' in kw) and (kw['out'] is not None): + kw['out'] = np.asarray(kw['out']) r = func(*inputs, **kw) - if 'out' in kw: - return r - else: - x = SomeClass2(r.shape, dtype=r.dtype) - x[...] = r - return x + x = SomeClass2(r.shape, dtype=r.dtype) + x[...] = r + return x arr = np.array([0]) obj = SomeClass() @@ -2276,6 +2276,23 @@ def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw): assert_equal(obj2.sum(), 42) assert_(isinstance(obj2, SomeClass2)) + def test_out_override(self): + # regression test for github bug 4753 + class OutClass(ndarray): + def __numpy_ufunc__(self, ufunc, method, i, inputs, **kw): + if 'out' in kw: + tmp_kw = kw.copy() + tmp_kw.pop('out') + func = getattr(ufunc, method) + kw['out'][...] = func(*inputs, **tmp_kw) + + A = np.array([0]).view(OutClass) + B = np.array([5]) + C = np.array([6]) + np.multiply(C, B, out=A) + assert_equal(A[0], 30) + assert_(isinstance(A, OutClass)) + class TestCAPI(TestCase): def test_IsPythonScalar(self):