diff --git a/benchmarks/benchmarks/bench_overrides.py b/benchmarks/benchmarks/bench_overrides.py new file mode 100644 index 000000000000..2cb94c95ce60 --- /dev/null +++ b/benchmarks/benchmarks/bench_overrides.py @@ -0,0 +1,61 @@ +from __future__ import absolute_import, division, print_function + +from .common import Benchmark + +from numpy.core.overrides import array_function_dispatch +import numpy as np + + +def _broadcast_to_dispatcher(array, shape, subok=None): + return (array,) + + +@array_function_dispatch(_broadcast_to_dispatcher) +def mock_broadcast_to(array, shape, subok=False): + pass + + +def _concatenate_dispatcher(arrays, axis=None, out=None): + for array in arrays: + yield array + if out is not None: + yield out + + +@array_function_dispatch(_concatenate_dispatcher) +def mock_concatenate(arrays, axis=0, out=None): + pass + + +class DuckArray(object): + def __array_function__(self, func, types, args, kwargs): + pass + + +class ArrayFunction(Benchmark): + + def setup(self): + self.numpy_array = np.array(1) + self.numpy_arrays = [np.array(1), np.array(2)] + self.many_arrays = 500 * self.numpy_arrays + self.duck_array = DuckArray() + self.duck_arrays = [DuckArray(), DuckArray()] + self.mixed_arrays = [np.array(1), DuckArray()] + + def time_mock_broadcast_to_numpy(self): + mock_broadcast_to(self.numpy_array, ()) + + def time_mock_broadcast_to_duck(self): + mock_broadcast_to(self.duck_array, ()) + + def time_mock_concatenate_numpy(self): + mock_concatenate(self.numpy_arrays, axis=0) + + def time_mock_concatenate_many(self): + mock_concatenate(self.many_arrays, axis=0) + + def time_mock_concatenate_duck(self): + mock_concatenate(self.duck_arrays, axis=0) + + def time_mock_concatenate_mixed(self): + mock_concatenate(self.mixed_arrays, axis=0) diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 33f6d01a89c3..8974f0ce1a9f 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -154,3 +154,18 @@ def _ptp(a, axis=None, out=None, keepdims=False): umr_minimum(a, axis, None, None, keepdims), out ) + +_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__ + +def _array_function(self, func, types, args, kwargs): + # TODO: rewrite this in C + # Cannot handle items that have __array_function__ other than our own. + for t in types: + if t is not mu.ndarray: + method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION) + if method is not _NDARRAY_ARRAY_FUNCTION: + return NotImplemented + + # Arguments contain no overrides, so we can safely call the + # overloaded function again. + return func(*args, **kwargs) diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py new file mode 100644 index 000000000000..c1d5e38643aa --- /dev/null +++ b/numpy/core/overrides.py @@ -0,0 +1,86 @@ +"""Preliminary implementation of NEP-18 + +TODO: rewrite this in C for performance. +""" +import functools +from numpy.core.multiarray import ndarray + + +_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ + + +def get_overloaded_types_and_args(relevant_args): + """Returns a list of arguments on which to call __array_function__. + + __array_function__ implementations should be called in order on the return + values from this function. + """ + # Runtime is O(num_arguments * num_unique_types) + overloaded_types = [] + overloaded_args = [] + for arg in relevant_args: + arg_type = type(arg) + if (arg_type not in overloaded_types and + hasattr(arg_type, '__array_function__')): + + overloaded_types.append(arg_type) + + # By default, insert this argument at the end, but if it is + # subclass of another argument, insert it before that argument. + # This ensures "subclasses before superclasses". + index = len(overloaded_args) + for i, old_arg in enumerate(overloaded_args): + if issubclass(arg_type, type(old_arg)): + index = i + break + overloaded_args.insert(index, arg) + + # Special handling for ndarray. + overloaded_args = [ + arg for arg in overloaded_args + if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION + ] + + return overloaded_types, overloaded_args + + +def array_function_override(overloaded_args, func, types, args, kwargs): + """Call __array_function__ implementations.""" + for overloaded_arg in overloaded_args: + # Note that we're only calling __array_function__ on the *first* + # occurence of each argument type. This is necessary for reasonable + # performance with a possibly long list of overloaded arguments, for + # which each __array_function__ implementation might reasonably need to + # check all argument types. + result = overloaded_arg.__array_function__(func, types, args, kwargs) + + if result is not NotImplemented: + return result + + raise TypeError('no implementation found for {} on types that implement ' + '__array_function__: {}' + .format(func, list(map(type, overloaded_args)))) + + +def array_function_dispatch(dispatcher): + """Wrap a function for dispatch with the __array_function__ protocol.""" + def decorator(func): + @functools.wraps(func) + def new_func(*args, **kwargs): + # Collect array-like arguments. + relevant_arguments = dispatcher(*args, **kwargs) + # Check for __array_function__ methods. + types, overloaded_args = get_overloaded_types_and_args( + relevant_arguments) + # Call overrides, if necessary. + if overloaded_args: + # new_func is the function exposed in NumPy's public API. We + # use it instead of func so __array_function__ implementations + # can do equality/identity comparisons. + return array_function_override( + overloaded_args, new_func, types, args, kwargs) + else: + return func(*args, **kwargs) + + return new_func + return decorator diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 3d2cce5e18f5..6317d6a16a5d 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1021,6 +1021,13 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds) } +static PyObject * +array_function(PyArrayObject *self, PyObject *args, PyObject *kwds) +{ + NPY_FORWARD_NDARRAY_METHOD("_array_function"); +} + + static PyObject * array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds) { @@ -2472,6 +2479,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = { {"__array_ufunc__", (PyCFunction)array_ufunc, METH_VARARGS | METH_KEYWORDS, NULL}, + {"__array_function__", + (PyCFunction)array_function, + METH_VARARGS | METH_KEYWORDS, NULL}, #ifndef NPY_PY3K {"__unicode__", diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py new file mode 100644 index 000000000000..7f6157a5bf09 --- /dev/null +++ b/numpy/core/tests/test_overrides.py @@ -0,0 +1,273 @@ +from __future__ import division, absolute_import, print_function + +import pickle +import sys + +import numpy as np +from numpy.testing import ( + assert_, assert_equal, assert_raises, assert_raises_regex) +from numpy.core.overrides import ( + get_overloaded_types_and_args, array_function_dispatch) + + +def _get_overloaded_args(relevant_args): + types, args = get_overloaded_types_and_args(relevant_args) + return args + + +def _return_self(self, *args, **kwargs): + return self + + +class TestGetOverloadedTypesAndArgs(object): + + def test_ndarray(self): + array = np.array(1) + + types, args = get_overloaded_types_and_args([array]) + assert_equal(set(types), {np.ndarray}) + assert_equal(list(args), []) + + types, args = get_overloaded_types_and_args([array, array]) + assert_equal(len(types), 1) + assert_equal(set(types), {np.ndarray}) + assert_equal(list(args), []) + + types, args = get_overloaded_types_and_args([array, 1]) + assert_equal(set(types), {np.ndarray}) + assert_equal(list(args), []) + + types, args = get_overloaded_types_and_args([1, array]) + assert_equal(set(types), {np.ndarray}) + assert_equal(list(args), []) + + def test_ndarray_subclasses(self): + + class OverrideSub(np.ndarray): + __array_function__ = _return_self + + class NoOverrideSub(np.ndarray): + pass + + array = np.array(1).view(np.ndarray) + override_sub = np.array(1).view(OverrideSub) + no_override_sub = np.array(1).view(NoOverrideSub) + + types, args = get_overloaded_types_and_args([array, override_sub]) + assert_equal(set(types), {np.ndarray, OverrideSub}) + assert_equal(list(args), [override_sub]) + + types, args = get_overloaded_types_and_args([array, no_override_sub]) + assert_equal(set(types), {np.ndarray, NoOverrideSub}) + assert_equal(list(args), []) + + types, args = get_overloaded_types_and_args( + [override_sub, no_override_sub]) + assert_equal(set(types), {OverrideSub, NoOverrideSub}) + assert_equal(list(args), [override_sub]) + + def test_ndarray_and_duck_array(self): + + class Other(object): + __array_function__ = _return_self + + array = np.array(1) + other = Other() + + types, args = get_overloaded_types_and_args([other, array]) + assert_equal(set(types), {np.ndarray, Other}) + assert_equal(list(args), [other]) + + types, args = get_overloaded_types_and_args([array, other]) + assert_equal(set(types), {np.ndarray, Other}) + assert_equal(list(args), [other]) + + def test_ndarray_subclass_and_duck_array(self): + + class OverrideSub(np.ndarray): + __array_function__ = _return_self + + class Other(object): + __array_function__ = _return_self + + array = np.array(1) + subarray = np.array(1).view(OverrideSub) + other = Other() + + assert_equal(_get_overloaded_args([array, subarray, other]), + [subarray, other]) + assert_equal(_get_overloaded_args([array, other, subarray]), + [subarray, other]) + + def test_many_duck_arrays(self): + + class A(object): + __array_function__ = _return_self + + class B(A): + __array_function__ = _return_self + + class C(A): + __array_function__ = _return_self + + class D(object): + __array_function__ = _return_self + + a = A() + b = B() + c = C() + d = D() + + assert_equal(_get_overloaded_args([1]), []) + assert_equal(_get_overloaded_args([a]), [a]) + assert_equal(_get_overloaded_args([a, 1]), [a]) + assert_equal(_get_overloaded_args([a, a, a]), [a]) + assert_equal(_get_overloaded_args([a, d, a]), [a, d]) + assert_equal(_get_overloaded_args([a, b]), [b, a]) + assert_equal(_get_overloaded_args([b, a]), [b, a]) + assert_equal(_get_overloaded_args([a, b, c]), [b, c, a]) + assert_equal(_get_overloaded_args([a, c, b]), [c, b, a]) + + +class TestNDArrayArrayFunction(object): + + def test_method(self): + + class SubOverride(np.ndarray): + __array_function__ = _return_self + + class NoOverrideSub(np.ndarray): + pass + + array = np.array(1) + + def func(): + return 'original' + + result = array.__array_function__( + func=func, types=(np.ndarray,), args=(), kwargs={}) + assert_equal(result, 'original') + + result = array.__array_function__( + func=func, types=(np.ndarray, SubOverride), args=(), kwargs={}) + assert_(result is NotImplemented) + + result = array.__array_function__( + func=func, types=(np.ndarray, NoOverrideSub), args=(), kwargs={}) + assert_equal(result, 'original') + + +# need to define this at the top level to test pickling +@array_function_dispatch(lambda array: (array,)) +def dispatched_one_arg(array): + """Docstring.""" + return 'original' + + +class TestArrayFunctionDispatch(object): + + def test_pickle(self): + roundtripped = pickle.loads(pickle.dumps(dispatched_one_arg)) + assert_(roundtripped is dispatched_one_arg) + + def test_name_and_docstring(self): + assert_equal(dispatched_one_arg.__name__, 'dispatched_one_arg') + if sys.flags.optimize < 2: + assert_equal(dispatched_one_arg.__doc__, 'Docstring.') + + def test_interface(self): + + class MyArray(object): + def __array_function__(self, func, types, args, kwargs): + return (self, func, types, args, kwargs) + + original = MyArray() + (obj, func, types, args, kwargs) = dispatched_one_arg(original) + assert_(obj is original) + assert_(func is dispatched_one_arg) + assert_equal(set(types), {MyArray}) + assert_equal(args, (original,)) + assert_equal(kwargs, {}) + + def test_not_implemented(self): + + class MyArray(object): + def __array_function__(self, func, types, args, kwargs): + return NotImplemented + + array = MyArray() + with assert_raises_regex(TypeError, 'no implementation found'): + dispatched_one_arg(array) + + +def _new_duck_type_and_implements(): + """Create a duck array type and implements functions.""" + HANDLED_FUNCTIONS = {} + + class MyArray(object): + def __array_function__(self, func, types, args, kwargs): + if func not in HANDLED_FUNCTIONS: + return NotImplemented + if not all(issubclass(t, MyArray) for t in types): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) + + def implements(numpy_function): + """Register an __array_function__ implementations.""" + def decorator(func): + HANDLED_FUNCTIONS[numpy_function] = func + return func + return decorator + + return (MyArray, implements) + + +class TestArrayFunctionImplementation(object): + + def test_one_arg(self): + MyArray, implements = _new_duck_type_and_implements() + + @implements(dispatched_one_arg) + def _(array): + return 'myarray' + + assert_equal(dispatched_one_arg(1), 'original') + assert_equal(dispatched_one_arg(MyArray()), 'myarray') + + def test_optional_args(self): + MyArray, implements = _new_duck_type_and_implements() + + @array_function_dispatch(lambda array, option=None: (array,)) + def func_with_option(array, option='default'): + return option + + @implements(func_with_option) + def my_array_func_with_option(array, new_option='myarray'): + return new_option + + # we don't need to implement every option on __array_function__ + # implementations + assert_equal(func_with_option(1), 'default') + assert_equal(func_with_option(1, option='extra'), 'extra') + assert_equal(func_with_option(MyArray()), 'myarray') + with assert_raises(TypeError): + func_with_option(MyArray(), option='extra') + + # but new options on implementations can't be used + result = my_array_func_with_option(MyArray(), new_option='yes') + assert_equal(result, 'yes') + with assert_raises(TypeError): + func_with_option(MyArray(), new_option='no') + + def test_not_implemented(self): + MyArray, implements = _new_duck_type_and_implements() + + @array_function_dispatch(lambda array: (array,)) + def func(array): + return array + + array = np.array(1) + assert_(func(array) is array) + + with assert_raises_regex(TypeError, 'no implementation found'): + func(MyArray())