diff --git a/numpy/add_newdocs.py b/numpy/add_newdocs.py index e8030d5621ff..e4ede318687d 100644 --- a/numpy/add_newdocs.py +++ b/numpy/add_newdocs.py @@ -6277,6 +6277,27 @@ def luf(lamdaexpr, *args, **kwargs): """)) + +############################################################################## +# +# Documentation for ufunc_wrapper +# +############################################################################## + +add_newdoc('numpy.core', 'ufunc_wraper', + """ + Decorator class to allow a ufunc-like class method to use the `__array_ufunc__` + mechanism + + Examples + -------- + + >>> class ArrayLike: + ... pass + + + + """) ############################################################################## # # Documentation for dtype attributes and methods diff --git a/numpy/core/include/numpy/ufuncobject.h b/numpy/core/include/numpy/ufuncobject.h index d0ac1fd7d732..30c9cece6ee7 100644 --- a/numpy/core/include/numpy/ufuncobject.h +++ b/numpy/core/include/numpy/ufuncobject.h @@ -111,15 +111,20 @@ typedef int (PyUFunc_MaskedInnerLoopSelectionFunc)( NpyAuxData **out_innerloopdata, int *out_needs_api); -typedef struct _tagPyUFuncObject { - PyObject_HEAD - /* +/* used in both ufunc and ufunc_wrapper */ +#define UFUNC_BASE \ + PyObject_HEAD /* * nin: Number of inputs * nout: Number of outputs * nargs: Always nin + nout (Why is it stored?) - */ - int nin, nout, nargs; + */ int nin, nout, nargs; +typedef struct { + UFUNC_BASE +} PyUFuncBaseObject; + +typedef struct _tagPyUFuncObject { + UFUNC_BASE /* Identity for reduction, either PyUFunc_One or PyUFunc_Zero */ int identity; diff --git a/numpy/core/numeric.py b/numpy/core/numeric.py index 7ade3d22451a..45ccdd69b1df 100644 --- a/numpy/core/numeric.py +++ b/numpy/core/numeric.py @@ -28,7 +28,8 @@ from .multiarray import newbuffer, getbuffer from . import umath -from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT, +from .umath import (multiply, invert, sin, ufunc_wrapper, + UFUNC_BUFSIZE_DEFAULT, ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT, PINF, NAN) from . import numerictypes @@ -36,6 +37,7 @@ from ._internal import TooHardError, AxisError bitwise_not = invert +# TODO properly export ufunc from umath ufunc = type(sin) newaxis = None @@ -48,6 +50,32 @@ import __builtin__ as builtins +# TODO: rewrite the __call__ in C? +class UFuncWrapper(ufunc_wrapper): + ''' A decorator to wrap ufunc-like class methods and enable the + __array_ufunc__ protocol + + Use as + + class MyArray(ndarray): + @UFuncWrapper(2, 1) # nin, nout + def solve(self, other): + ... + return result + ''' + def __call__(self, meth): + def wrap(*args, **kwds): + (status, r) = self.check_override(*args, **kwds) + if status > 0: + return r + return meth(*args, **kwds) + wrap.__name__ = meth.__name__ + wrap.__doc__ = meth.__doc__ + self.__name__ = meth.__name__ + self.__doc__ = meth.__doc__ + return wrap + + def loads(*args, **kwargs): # NumPy 1.15.0, 2017-12-10 warnings.warn( @@ -74,7 +102,7 @@ def loads(*args, **kwargs): 'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS', 'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like', 'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS', - 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError'] + 'MAY_SHARE_EXACT', 'TooHardError', 'AxisError', 'UFuncWrapper'] if sys.version_info[0] < 3: __all__.extend(['getbuffer', 'newbuffer']) diff --git a/numpy/core/src/multiarray/number.c b/numpy/core/src/multiarray/number.c index 14389a925843..9cfda3c2e78e 100644 --- a/numpy/core/src/multiarray/number.c +++ b/numpy/core/src/multiarray/number.c @@ -369,13 +369,44 @@ array_divmod(PyArrayObject *m1, PyObject *m2) static PyObject * array_matrix_multiply(PyArrayObject *m1, PyObject *m2) { - static PyObject *matmul = NULL; - - npy_cache_import("numpy.core.multiarray", "matmul", &matmul); - if (matmul == NULL) { - return NULL; + static PyObject *matmul=NULL, *wrapper=NULL, *ufunc=NULL, *checker=NULL; + PyObject *result = Py_None, *res_tuple; + int status; + if (ufunc == NULL) { + PyObject *s; + npy_cache_import("numpy.core.multiarray", "matmul", &matmul); + if (matmul == NULL) { + return NULL; + } + npy_cache_import("numpy.core.umath", "ufunc_wrapper", &wrapper); + if (wrapper == NULL) { + return NULL; + } + checker = PyObject_GetAttrString(wrapper, "check_override"); + if (checker == NULL) { + return NULL; + } + s = Py_BuildValue("ii", 2, 1); + ufunc = PyObject_CallObject(wrapper, s); + Py_DECREF(s); + if (ufunc == NULL) { + return NULL; + } } BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_matrix_multiply, array_matrix_multiply); + res_tuple = PyObject_CallFunctionObjArgs(checker, ufunc, m1, m2, NULL); + if (PyArg_ParseTuple(res_tuple, "iO", &status, &result) < 0) { + Py_DECREF(res_tuple); + return NULL; + } + if (status > 0) { + Py_INCREF(result); + Py_DECREF(res_tuple); + return result; + } + Py_DECREF(res_tuple); + if (PyErr_Occurred()) + return NULL; return PyArray_GenericBinaryFunction(m1, m2, matmul); } diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c index e405155cf9a5..6faab1d39e03 100644 --- a/numpy/core/src/private/ufunc_override.c +++ b/numpy/core/src/private/ufunc_override.c @@ -41,6 +41,12 @@ get_non_default_array_ufunc(PyObject *obj) return NULL; } /* does the class define __array_ufunc__? */ +#if PY_VERSION_HEX < 0x03000000 + if (Py_TYPE(obj) == &PyInstance_Type) { + PyErr_Warn(PyExc_RuntimeWarning, + "cannot lookup __array_ufunc__ on old-style classes"); + } +#endif cls_array_ufunc = PyArray_LookupSpecial(obj, "__array_ufunc__"); if (cls_array_ufunc == NULL) { return NULL; diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c index 123d9af87965..38ec5f12a2d9 100644 --- a/numpy/core/src/umath/override.c +++ b/numpy/core/src/umath/override.c @@ -37,7 +37,7 @@ normalize_signature_keyword(PyObject *normal_kwds) } static int -normalize___call___args(PyUFuncObject *ufunc, PyObject *args, +normalize___call___args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* @@ -114,7 +114,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args, } static int -normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args, +normalize_reduce_args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* @@ -169,7 +169,7 @@ normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args, } static int -normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args, +normalize_accumulate_args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* @@ -215,7 +215,7 @@ normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args, } static int -normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args, +normalize_reduceat_args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* @@ -263,7 +263,7 @@ normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args, } static int -normalize_outer_args(PyUFuncObject *ufunc, PyObject *args, +normalize_outer_args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* @@ -297,7 +297,7 @@ normalize_outer_args(PyUFuncObject *ufunc, PyObject *args, } static int -normalize_at_args(PyUFuncObject *ufunc, PyObject *args, +normalize_at_args(PyUFuncBaseObject *ufunc, PyObject *args, PyObject **normal_args, PyObject **normal_kwds) { /* ufunc.at(a, indices[, b]) */ @@ -325,7 +325,7 @@ normalize_at_args(PyUFuncObject *ufunc, PyObject *args, * result of the operation, if any. If *result is NULL, there is no override. */ NPY_NO_EXPORT int -PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, +PyUFunc_CheckOverride(PyUFuncBaseObject *ufunc, char *method, PyObject *args, PyObject *kwds, PyObject **result) { diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h index 68f3c6ef0814..4c27127845fe 100644 --- a/numpy/core/src/umath/override.h +++ b/numpy/core/src/umath/override.h @@ -5,7 +5,7 @@ #include "numpy/ufuncobject.h" NPY_NO_EXPORT int -PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, +PyUFunc_CheckOverride(PyUFuncBaseObject *ufunc, char *method, PyObject *args, PyObject *kwds, PyObject **result); #endif diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index af415362bd33..9ccac661d83e 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4368,7 +4368,8 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) ufunc_full_args full_args = {NULL, NULL}; int errval; - errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "__call__", + args, kwds, &override); if (errval) { return NULL; } @@ -5055,7 +5056,8 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyObject *new_args, *tmp; PyObject *shape1, *shape2, *newshape; - errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "outer", + args, kwds, &override); if (errval) { return NULL; } @@ -5151,7 +5153,8 @@ ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) int errval; PyObject *override = NULL; - errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "reduce", + args, kwds, &override); if (errval) { return NULL; } @@ -5167,7 +5170,8 @@ ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) int errval; PyObject *override = NULL; - errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "accumulate", + args, kwds, &override); if (errval) { return NULL; } @@ -5183,7 +5187,8 @@ ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) int errval; PyObject *override = NULL; - errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "reduceat", + args, kwds, &override); if (errval) { return NULL; } @@ -5247,7 +5252,8 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) char * err_msg = NULL; NPY_BEGIN_THREADS_DEF; - errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override); + errval = PyUFunc_CheckOverride((PyUFuncBaseObject*)ufunc, "at", + args, NULL, &override); if (errval) { return NULL; } @@ -5781,3 +5787,181 @@ NPY_NO_EXPORT PyTypeObject PyUFunc_Type = { }; /* End of code for ufunc objects */ + +typedef struct { + UFUNC_BASE + PyObject * class; +} PyUFuncWrapperObject; + +static void +ufuncwrapper_dealloc(PyUFuncWrapperObject *ufunc) +{ + Py_XDECREF(ufunc->class); + Py_TYPE(ufunc)->tp_free(ufunc); +} + +static PyObject * +ufuncwrapper_repr(PyUFuncWrapperObject *ufunc) +{ + PyObject * name = NULL; + PyObject * str; + if (ufunc->class) { + name = PyObject_GetAttrString(ufunc->class, "__name__"); + } + if (name != NULL) { + str = PyUString_FromFormat("", PyString_AsString(name)); + } + else { + str = PyUString_FromFormat("", "?"); + } + Py_XDECREF(name); + return str; +} + +#if 0 +static PyObject * myfunc(PyObject* self, PyObject * a) { return a; }; + +static PyObject * +ufuncwrapper_call(PyObject *self, PyObject *args, PyObject *kwargs) { + /* Not yet working, moved the code to pure python in numeric.py */ + PyFunctionObject * meth; + PyObject *x; + PyUFuncWrapperObject *s = (PyUFuncWrapperObject*)self; + PyMethodDef methdef; + if (!PyArg_ParseTuple(args, "O", &meth)) { + return NULL; + } + if (!PyCallable_Check((PyObject*)meth)) { + PyErr_SetString(PyExc_TypeError, "wrapped object must be a bound class method"); + return NULL; + } + //methdef.ml_name = PyBytes_AsString(meth->func_name); + methdef.ml_name = "wrapped function"; + methdef.ml_flags = METH_VARARGS | METH_KEYWORDS; + methdef.ml_meth = myfunc; + methdef.ml_doc = meth->func_doc; + x = PyCFunction_New(&methdef, NULL); + meth = PyMethod_New(x, s->class); + Py_DECREF(s); + return meth; +} +#endif + +static PyObject *ufuncwrapper_new(PyTypeObject *t, PyObject *a, PyObject *k) +{ + PyObject *o; + o = t->tp_alloc(t, 0); + return o; +} + +static int +ufuncwrapper_init(PyUFuncWrapperObject * self, PyObject *args, PyObject *kwargs) +{ + if (!PyArg_ParseTuple(args, "ii", &self->nin, &self->nout)) + return -1; + + self->nargs = self->nin + self->nout; + return 0; +} + +static PyObject * +ufunc_check_override(PyObject *self, PyObject *args, PyObject *kwds) { + PyUFuncBaseObject *ufunc = (PyUFuncBaseObject*)self; + char * method = "__call__"; + PyObject * result = NULL; + int status; + status = PyUFunc_CheckOverride(ufunc, method, args, kwds, &result); + if (status) { + return NULL; + } + else if (result) { + return Py_BuildValue("iO", 1, result); + } + return Py_BuildValue("iO", 0, Py_None); +} + +static struct PyMethodDef ufuncwrapper_methods[] = { + {"check_override", + (PyCFunction)ufunc_check_override, + METH_VARARGS | METH_KEYWORDS, NULL}, + {NULL, NULL, 0, NULL} /* sentinel */ +}; + + +static PyGetSetDef ufuncwrapper_getset[] = { + {"nin", + (getter)ufunc_get_nin, + NULL, NULL, NULL}, + {"nout", + (getter)ufunc_get_nout, + NULL, NULL, NULL}, + {"nargs", + (getter)ufunc_get_nargs, + NULL, NULL, NULL}, + {NULL, NULL, NULL, NULL, NULL}, /* Sentinel */ +}; + +/****************************************************************************** + *** UFUNC_WRAPPER TYPE OBJECT *** + *****************************************************************************/ + +NPY_NO_EXPORT PyTypeObject PyUFuncWrapper_Type = { +#if defined(NPY_PY3K) + PyVarObject_HEAD_INIT(NULL, 0) +#else + PyObject_HEAD_INIT(NULL) + 0, /* ob_size */ +#endif + "numpy.ufunc_wrapper", /* tp_name */ + sizeof(PyUFuncWrapperObject), /* tp_basicsize */ + 0, /* tp_itemsize */ + /* methods */ + (destructor)ufuncwrapper_dealloc, /* tp_dealloc */ + 0, /* tp_print */ + 0, /* tp_getattr */ + 0, /* tp_setattr */ +#if defined(NPY_PY3K) + 0, /* tp_reserved */ +#else + 0, /* tp_compare */ +#endif + (reprfunc)ufuncwrapper_repr, /* tp_repr */ + 0, /* tp_as_number */ + 0, /* tp_as_sequence */ + 0, /* tp_as_mapping */ + 0, /* tp_hash */ + 0, /* tp_call */ + 0, /* tp_str */ + 0, /* tp_getattro */ + 0, /* tp_setattro */ + 0, /* tp_as_buffer */ + Py_TPFLAGS_DEFAULT |Py_TPFLAGS_BASETYPE, /* tp_flags */ + 0, /* tp_doc */ + 0, /* tp_traverse */ + 0, /* tp_clear */ + 0, /* tp_richcompare */ + 0, /* tp_weaklistoffset */ + 0, /* tp_iter */ + 0, /* tp_iternext */ + ufuncwrapper_methods, /* tp_methods */ + 0, /* tp_members */ + ufuncwrapper_getset, /* tp_getset */ + 0, /* tp_base */ + 0, /* tp_dict */ + 0, /* tp_descr_get */ + 0, /* tp_descr_set */ + 0, /* tp_dictoffset */ + (initproc)ufuncwrapper_init, /* tp_init */ + 0, /* tp_alloc */ + ufuncwrapper_new, /* tp_new */ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ + 0, /* tp_del */ + 0, /* tp_version_tag */ +}; + diff --git a/numpy/core/src/umath/umathmodule.c b/numpy/core/src/umath/umathmodule.c index 5567b9bbfab1..262ab8165444 100644 --- a/numpy/core/src/umath/umathmodule.c +++ b/numpy/core/src/umath/umathmodule.c @@ -253,6 +253,7 @@ intern_strings(void) /* Setup the umath module */ /* Remove for time being, it is declared in __ufunc_api.h */ /*static PyTypeObject PyUFunc_Type;*/ +extern PyTypeObject PyUFuncWrapper_Type; static struct PyMethodDef methods[] = { {"frompyfunc", @@ -323,6 +324,10 @@ PyMODINIT_FUNC initumath(void) if (PyType_Ready(&PyUFunc_Type) < 0) goto err; + if (PyType_Ready(&PyUFuncWrapper_Type) < 0) + goto err; + + PyModule_AddObject(m,"ufunc_wrapper",(PyObject*)&PyUFuncWrapper_Type); /* Add some symbolic constants to the module */ d = PyModule_GetDict(m); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 3ca201eddbed..6ebd9ca14998 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -3329,7 +3329,7 @@ def __mul__(self, other): def __div__(self, other): raise AssertionError('__div__ should not be called') - + def __pow__(self, exp): return SomeClass(num=self.num ** exp) @@ -3349,7 +3349,53 @@ def pow_for(exp, arr): assert_equal(obj_arr ** 1, pow_for(1, obj_arr)) assert_equal(obj_arr ** -1, pow_for(-1, obj_arr)) assert_equal(obj_arr ** 2, pow_for(2, obj_arr)) - + + def test_wrapper(self): + + class Solver(np.ndarray): + @np.UFuncWrapper(2, 1) + def solve(self, other, out=None): + # Solve for other given self + if out: + out[...] = (self + other.T) / 2.0 + return out + return (self + other.T) / 2.0 + + class Overrider(object): + def __array_ufunc__(self, ufunc, method, *args, **kwargs): + return ufunc.__name__, method + + @property + def T(self): + raise TypeError('__array_ufunc__ override not successful') + + class InstanceOverrider(): + # old style class on Python2, will warn and not override + def __array_ufunc__(self, ufunc, method, *args, **kwargs): + return ufunc.__name__, method + + @property + def T(self): + return 4 + + s = np.arange(12).reshape(3,4).view(Solver) + a = np.arange(11, -1, -1).reshape(3,4).T + o = Overrider() + assert_equal(s.solve(a), 5.5) + assert_equal(s.solve(o), ('solve', '__call__')) + + i = InstanceOverrider() + with suppress_warnings() as sup: + sup.record(RuntimeWarning) + ISPY2 = sys.version_info[0] < 3 + if ISPY2: + # did not override via __array_ufunc__ + assert_equal(s.solve(i), (s + 4) / 2.0) + assert len(sup.log) == 1 + else: + assert_equal(s.solve(o), ('solve', '__call__')) + assert len(sup.log) == 0 + class TestTemporaryElide(object): # elision is only triggered on relatively large arrays @@ -5602,6 +5648,7 @@ def test_out_arg(self): class TestMatmulOperator(MatmulCommon): import operator matmul = operator.matmul + mul = operator.mul def test_array_priority_override(self): @@ -5619,6 +5666,29 @@ def __rmatmul__(self, other): assert_equal(self.matmul(a, b), "A") assert_equal(self.matmul(b, a), "A") + def test_array_ufunc_override(self): + class OptOut: + __array_ufunc__ = None + def __rmatmul__(self, other): + return 'rmatmul' + + class OtherArray: + def __array_ufunc__(self, *args, **kwargs): + return 'array_ufunc' + def __rmatmul__(self, other): + return 'rmatmul' + + array = np.arange(3) + opt_out = OptOut() + other_array = OtherArray() + + # OptOut works as expected: + assert_raises(TypeError, self.mul, array, opt_out) + assert_equal(self.matmul(array, opt_out), 'rmatmul') + + assert_equal(array * other_array, 'array_ufunc') + assert_equal(self.matmul(array, other_array), 'array_ufunc') + def test_matmul_inplace(): # It would be nice to support in-place matmul eventually, but for now # we don't have a working implementation, so better just to error out