8000 WIP: ENH: allow __array_ufunc__ to override __matmul__ by mattip · Pull Request #11061 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

WIP: ENH: allow __array_ufunc__ to override __matmul__ #11061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions numpy/add_newdocs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6277,6 +6277,27 @@ def luf(lamdaexpr, *args, **kwargs):

"""))


##############################################################################
#
# Documentation for ufunc_wrapper
#
##############################################################################

add_newdoc('numpy.core', 'ufunc_wraper',
"""
Decorator class to allow a ufunc-like class method to use the `__array_ufunc__`
mechanism

Examples
--------

>>> class ArrayLike:
... pass



""")
##############################################################################
#
# Documentation for dtype attributes and methods
Expand Down
15 changes: 10 additions & 5 deletions numpy/core/include/numpy/ufuncobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,20 @@ typedef int (PyUFunc_MaskedInnerLoopSelectionFunc)(
NpyAuxData **out_innerloopdata,
int *out_needs_api);

typedef struct _tagPyUFuncObject {
PyObject_HEAD
/*
/* used in both ufunc and ufunc_wrapper */
#define UFUNC_BASE \
PyObject_HEAD /*
* nin: Number of inputs
* nout: Number of outputs
* nargs: Always nin + nout (Why is it stored?)
*/
int nin, nout, nargs;
*/ int nin, nout, nargs;

typedef struct {
UFUNC_BASE
} PyUFuncBaseObject;

typedef struct _tagPyUFuncObject {
UFUNC_BASE
/* Identity for reduction, either PyUFunc_One or PyUFunc_Zero */
int identity;

Expand Down
32 changes: 30 additions & 2 deletions numpy/core/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
from .multiarray import newbuffer, getbuffer

from . import umath
from .umath import (multiply, invert, sin, UFUNC_BUFSIZE_DEFAULT,
from .umath import (multiply, invert, sin, ufunc_wrapper,
UFUNC_BUFSIZE_DEFAULT,
ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT,
ERR_LOG, ERR_DEFAULT, PINF, NAN)
from . import numerictypes
from .numerictypes import longlong, intc, int_, float_, complex_, bool_
from ._internal import TooHardError, AxisError

bitwise_not = invert
# TODO properly export ufunc from umath
ufunc = type(sin)
newaxis = None

Expand All @@ -48,6 +50,32 @@
import __builtin__ as builtins


# TODO: rewrite the __call__ in C?
class UFuncWrapper(ufunc_wrapper):
''' A decorator to wrap ufunc-like class methods and enable the
__array_ufunc__ protocol

Use as

class MyArray(ndarray):
@UFuncWrapper(2, 1) # nin, nout
def solve(self, other):
...
return result
'''
def __call__(self, meth):
def wrap(*args, **kwds):
(status, r) = self.check_override(*args, **kwds)
if status > 0:
return r
return meth(*args, **kwds)
wrap.__name__ = meth.__name__
wrap.__doc__ = meth.__doc__
self.__name__ = meth.__name__
self.__doc__ = meth.__doc__
return wrap


def loads(*args, **kwargs):
# NumPy 1.15.0, 2017-12-10
warnings.warn(
Expand All @@ -74,7 +102,7 @@ def loads(*args, **kwargs):
'False_', 'True_', 'bitwise_not', 'CLIP', 'RAISE', 'WRAP', 'MAXDIMS',
'BUFSIZE', 'ALLOW_THREADS', 'ComplexWarning', 'full', 'full_like',
'matmul', 'shares_memory', 'may_share_memory', 'MAY_SHARE_BOUNDS',
'MAY_SHARE_EXACT', 'TooHardError', 'AxisError']
'MAY_SHARE_EXACT', 'TooHardError', 'AxisError', 'UFuncWrapper']

if sys.version_info[0] < 3:
__all__.extend(['getbuffer', 'newbuffer'])
Expand Down
41 changes: 36 additions & 5 deletions numpy/core/src/multiarray/number.c
Original file line number Diff line number Diff line change
Expand Up @@ -369,13 +369,44 @@ array_divmod(PyArrayObject *m1, PyObject *m2)
static PyObject *
array_matrix_multiply(PyArrayObject *m1, PyObject *m2)
{
static PyObject *matmul = NULL;

npy_cache_import("numpy.core.multiarray", "matmul", &matmul);
if (matmul == NULL) {
return NULL;
static PyObject *matmul=NULL, *wrapper=NULL, *ufunc=NULL, *checker=NULL;
PyObject *result = Py_None, *res_tuple;
int status;
if (ufunc == NULL) {
PyObject *s;
npy_cache_import("numpy.core.multiarray", "matmul", &matmul);
if (matmul == NULL) {
return NULL;
}
npy_cache_import("numpy.core.umath", "ufunc_wrapper", &wrapper);
if (wrapper == NULL) {
return NULL;
}
checker = PyObject_GetAttrString(wrapper, "check_override");
if (checker == NULL) {
return NULL;
}
s = Py_BuildValue("ii", 2, 1);
ufunc = PyObject_CallObject(wrapper, s);
Py_DECREF(s);
if (ufunc == NULL) {
return NULL;
}
}
BINOP_GIVE_UP_IF_NEEDED(m1, m2, nb_matrix_multiply, array_matrix_multiply);
res_tuple = PyObject_CallFunctionObjArgs(checker, ufunc, m1, m2, NULL);
if (PyArg_ParseTuple(res_tuple, "iO", &status, &result) < 0) {
Py_DECREF(res_tuple);
return NULL;
}
if (status > 0) {
Py_INCREF(result);
Py_DECREF(res_tuple);
return result;
}
Py_DECREF(res_tuple);
if (PyErr_Occurred())
return NULL;
return PyArray_GenericBinaryFunction(m1, m2, matmul);
}

Expand Down
6 changes: 6 additions & 0 deletions numpy/core/src/private/ufunc_override.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ get_non_default_array_ufunc(PyObject *obj)
return NULL;
}
/* does the class define __array_ufunc__? */
#if PY_VERSION_HEX < 0x03000000
if (Py_TYPE(obj) == &PyInstance_Type) {
PyErr_Warn(PyExc_RuntimeWarning,
"cannot lookup __array_ufunc__ on old-style classes");
}
#endif
cls_array_ufunc = PyArray_LookupSpecial(obj, "__array_ufunc__");
if (cls_array_ufunc == NULL) {
return NULL;
Expand Down
14 changes: 7 additions & 7 deletions numpy/core/src/umath/override.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ normalize_signature_keyword(PyObject *normal_kwds)
}

static int
normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
normalize___call___args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/*
Expand Down Expand Up @@ -114,7 +114,7 @@ normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
}

static int
normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args,
normalize_reduce_args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/*
Expand Down Expand Up @@ -169,7 +169,7 @@ normalize_reduce_args(PyUFuncObject *ufunc, PyObject *args,
}

static int
normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args,
normalize_accumulate_args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/*
Expand Down Expand Up @@ -215,7 +215,7 @@ normalize_accumulate_args(PyUFuncObject *ufunc, PyObject *args,
}

static int
normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args,
normalize_reduceat_args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/*
Expand Down Expand Up @@ -263,7 +263,7 @@ normalize_reduceat_args(PyUFuncObject *ufunc, PyObject *args,
}

static int
normalize_outer_args(PyUFuncObject *ufunc, PyObject *args,
normalize_outer_args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/*
Expand Down Expand Up @@ -297,7 +297,7 @@ normalize_outer_args(PyUFuncObject *ufunc, PyObject *args,
}

static int
normalize_at_args(PyUFuncObject *ufunc, PyObject *args,
normalize_at_args(PyUFuncBaseObject *ufunc, PyObject *args,
PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.at(a, indices[, b]) */
Expand Down Expand Up @@ -325,7 +325,7 @@ normalize_at_args(PyUFuncObject *ufunc, PyObject *args,
* result of the operation, if any. If *result is NULL, there is no override.
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyUFunc_CheckOverride(PyUFuncBaseObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
PyObject **result)
{
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/umath/override.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include "numpy/ufuncobject.h"

NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyUFunc_CheckOverride(PyUFuncBaseObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
PyObject **result);
#endif
Loading
0