diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 3f87a6afed35..d027f493e0c7 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -16,8 +16,8 @@ def _get_overloaded_args(relevant_args): return args -def _return_self(self, *args, **kwargs): - return self +def _return_not_implemented(self, *args, **kwargs): + return NotImplemented class TestGetOverloadedTypesAndArgs(object): @@ -45,7 +45,7 @@ def test_ndarray(self): def test_ndarray_subclasses(self): class OverrideSub(np.ndarray): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class NoOverrideSub(np.ndarray): pass @@ -70,7 +70,7 @@ class NoOverrideSub(np.ndarray): def test_ndarray_and_duck_array(self): class Other(object): - __array_function__ = _return_self + __array_function__ = _return_not_implemented array = np.array(1) other = Other() @@ -86,10 +86,10 @@ class Other(object): def test_ndarray_subclass_and_duck_array(self): class OverrideSub(np.ndarray): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class Other(object): - __array_function__ = _return_self + __array_function__ = _return_not_implemented array = np.array(1) subarray = np.array(1).view(OverrideSub) @@ -103,16 +103,16 @@ class Other(object): def test_many_duck_arrays(self): class A(object): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class B(A): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class C(A): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class D(object): - __array_function__ = _return_self + __array_function__ = _return_not_implemented a = A() b = B() @@ -135,7 +135,7 @@ class TestNDArrayArrayFunction(object): def test_method(self): class SubOverride(np.ndarray): - __array_function__ = _return_self + __array_function__ = _return_not_implemented class NoOverrideSub(np.ndarray): pass @@ -189,7 +189,8 @@ def __array_function__(self, func, types, args, kwargs): assert_(obj is original) assert_(func is dispatched_one_arg) assert_equal(set(types), {MyArray}) - assert_equal(args, (original,)) + # assert_equal uses the overloaded np.iscomplexobj() internally + assert_(args == (original,)) assert_equal(kwargs, {}) def test_not_implemented(self): diff --git a/numpy/lib/npyio.py b/numpy/lib/npyio.py index 62fc9c5b3961..733795671f4b 100644 --- a/numpy/lib/npyio.py +++ b/numpy/lib/npyio.py @@ -12,6 +12,7 @@ from . import format from ._datasource import DataSource from numpy.core.multiarray import packbits, unpackbits +from numpy.core.overrides import array_function_dispatch from numpy.core._internal import recursive from ._iotools import ( LineSplitter, NameValidator, StringConverter, ConverterError, @@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True, fid.close() +def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None): + return (arr,) + + +@array_function_dispatch(_save_dispatcher) def save(file, arr, allow_pickle=True, fix_imports=True): """ Save an array to a binary file in NumPy ``.npy`` format. @@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True): fid.close() +def _savez_dispatcher(file, *args, **kwds): + for a in args: + yield a + for v in kwds.values(): + yield v + + +@array_function_dispatch(_savez_dispatcher) def savez(file, *args, **kwds): """ Save several arrays into a single file in uncompressed ``.npz`` format. @@ -604,6 +618,14 @@ def savez(file, *args, **kwds): _savez(file, args, kwds, False) +def _savez_compressed_dispatcher(file, *args, **kwds): + for a in args: + yield a + for v in kwds.values(): + yield v + + +@array_function_dispatch(_savez_compressed_dispatcher) def savez_compressed(file, *args, **kwds): """ Save several arrays into a single file in compressed ``.npz`` format. @@ -1154,6 +1176,13 @@ def tobytes_first(x, conv): return X +def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None, + header=None, footer=None, comments=None, + encoding=None): + return (X,) + + +@array_function_dispatch(_savetxt_dispatcher) def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='', footer='', comments='# ', encoding=None): """ diff --git a/numpy/lib/polynomial.py b/numpy/lib/polynomial.py index 9f3b84732f99..09079db9d299 100644 --- a/numpy/lib/polynomial.py +++ b/numpy/lib/polynomial.py @@ -14,11 +14,13 @@ from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array, ones) +from numpy.core.overrides import array_function_dispatch from numpy.lib.twodim_base import diag, vander from numpy.lib.function_base import trim_zeros from numpy.lib.type_check import iscomplex, real, imag, mintypecode from numpy.linalg import eigvals, lstsq, inv + class RankWarning(UserWarning): """ Issued by `polyfit` when the Vandermonde matrix is rank deficient. @@ -29,6 +31,12 @@ class RankWarning(UserWarning): """ pass + +def _poly_dispatcher(seq_of_zeros): + return seq_of_zeros + + +@array_function_dispatch(_poly_dispatcher) def poly(seq_of_zeros): """ Find the coefficients of a polynomial with the given sequence of roots. @@ -145,6 +153,12 @@ def poly(seq_of_zeros): return a + +def _roots_dispatcher(p): + return p + + +@array_function_dispatch(_roots_dispatcher) def roots(p): """ Return the roots of a polynomial with coefficients given in p. @@ -229,6 +243,12 @@ def roots(p): roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype))) return roots + +def _polyint_dispatcher(p, m=None, k=None): + return (p,) + + +@array_function_dispatch(_polyint_dispatcher) def polyint(p, m=1, k=None): """ Return an antiderivative (indefinite integral) of a polynomial. @@ -322,6 +342,12 @@ def polyint(p, m=1, k=None): return poly1d(val) return val + +def _polyder_dispatcher(p, m=None): + return (p,) + + +@array_function_dispatch(_polyder_dispatcher) def polyder(p, m=1): """ Return the derivative of the specified order of a polynomial. @@ -390,6 +416,12 @@ def polyder(p, m=1): val = poly1d(val) return val + +def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None): + return (x, y, w) + + +@array_function_dispatch(_polyfit_dispatcher) def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): """ Least squares polynomial fit. @@ -610,6 +642,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False): return c +def _polyval_dispatcher(p, x): + return (p, x) + + +@array_function_dispatch(_polyval_dispatcher) def polyval(p, x): """ Evaluate a polynomial at specific values. @@ -679,6 +716,12 @@ def polyval(p, x): y = y * x + p[i] return y + +def _binary_op_dispatcher(a1, a2): + return (a1, a2) + + +@array_function_dispatch(_binary_op_dispatcher) def polyadd(a1, a2): """ Find the sum of two polynomials. @@ -739,6 +782,8 @@ def polyadd(a1, a2): val = poly1d(val) return val + +@array_function_dispatch(_binary_op_dispatcher) def polysub(a1, a2): """ Difference (subtraction) of two polynomials. @@ -786,6 +831,7 @@ def polysub(a1, a2): return val +@array_function_dispatch(_binary_op_dispatcher) def polymul(a1, a2): """ Find the product of two polynomials. @@ -842,6 +888,12 @@ def polymul(a1, a2): val = poly1d(val) return val + +def _polydiv_dispatcher(u, v): + return (u, v) + + +@array_function_dispatch(_polydiv_dispatcher) def polydiv(u, v): """ Returns the quotient and remainder of polynomial division. diff --git a/numpy/lib/recfunctions.py b/numpy/lib/recfunctions.py index b6453d5a2d64..53a586f56941 100644 --- a/numpy/lib/recfunctions.py +++ b/numpy/lib/recfunctions.py @@ -14,6 +14,7 @@ from numpy import ndarray, recarray from numpy.ma import MaskedArray from numpy.ma.mrecords import MaskedRecords +from numpy.core.overrides import array_function_dispatch from numpy.lib._iotools import _is_string_like from numpy.compat import basestring @@ -31,6 +32,11 @@ ] +def _recursive_fill_fields_dispatcher(input, output): + return (input, output) + + +@array_function_dispatch(_recursive_fill_fields_dispatcher) def recursive_fill_fields(input, output): """ Fills fields from output with fields from input, @@ -189,6 +195,11 @@ def flatten_descr(ndtype): return tuple(descr) +def _zip_dtype_dispatcher(seqarrays, flatten=None): + return seqarrays + + +@array_function_dispatch(_zip_dtype_dispatcher) def zip_dtype(seqarrays, flatten=False): newdtype = [] if flatten: @@ -205,6 +216,7 @@ def zip_dtype(seqarrays, flatten=False): return np.dtype(newdtype) +@array_function_dispatch(_zip_dtype_dispatcher) def zip_descr(seqarrays, flatten=False): """ Combine the dtype description of a series of arrays. @@ -297,6 +309,11 @@ def _izip_fields(iterable): yield element +def _izip_records_dispatcher(seqarrays, fill_value=None, flatten=None): + return seqarrays + + +@array_function_dispatch(_izip_records_dispatcher) def izip_records(seqarrays, fill_value=None, flatten=True): """ Returns an iterator of concatenated items from a sequence of arrays. @@ -357,6 +374,12 @@ def _fix_defaults(output, defaults=None): return output +def _merge_arrays_dispatcher(seqarrays, fill_value=None, flatten=None, + usemask=None, asrecarray=None): + return seqarrays + + +@array_function_dispatch(_merge_arrays_dispatcher) def merge_arrays(seqarrays, fill_value=-1, flatten=False, usemask=False, asrecarray=False): """ @@ -494,6 +517,11 @@ def merge_arrays(seqarrays, fill_value=-1, flatten=False, return output +def _drop_fields_dispatcher(base, drop_names, usemask=None, asrecarray=None): + return (base,) + + +@array_function_dispatch(_drop_fields_dispatcher) def drop_fields(base, drop_names, usemask=True, asrecarray=False): """ Return a new array with fields in `drop_names` dropped. @@ -583,6 +611,11 @@ def _keep_fields(base, keep_names, usemask=True, asrecarray=False): return _fix_output(output, usemask=usemask, asrecarray=asrecarray) +def _rec_drop_fields_dispatcher(base, drop_names): + return (base,) + + +@array_function_dispatch(_rec_drop_fields_dispatcher) def rec_drop_fields(base, drop_names): """ Returns a new numpy.recarray with fields in `drop_names` dropped. @@ -590,6 +623,11 @@ def rec_drop_fields(base, drop_names): return drop_fields(base, drop_names, usemask=False, asrecarray=True) +def _rename_fields_dispatcher(base, namemapper): + return (base,) + + +@array_function_dispatch(_rename_fields_dispatcher) def rename_fields(base, namemapper): """ Rename the fields from a flexible-datatype ndarray or recarray. @@ -629,6 +667,14 @@ def _recursive_rename_fields(ndtype, namemapper): return base.view(newdtype) +def _append_fields_dispatcher(base, names, data, dtypes=None, + fill_value=None, usemask=None, asrecarray=None): + yield base + for d in data: + yield d + + +@array_function_dispatch(_append_fields_dispatcher) def append_fields(base, names, data, dtypes=None, fill_value=-1, usemask=True, asrecarray=False): """ @@ -699,6 +745,13 @@ def append_fields(base, names, data, dtypes=None, return _fix_output(output, usemask=usemask, asrecarray=asrecarray) +def _rec_append_fields_dispatcher(base, names, data, dtypes=None): + yield base + for d in data: + yield d + + +@array_function_dispatch(_rec_append_fields_dispatcher) def rec_append_fields(base, names, data, dtypes=None): """ Add new fields to an existing array. @@ -732,6 +785,12 @@ def rec_append_fields(base, names, data, dtypes=None): return append_fields(base, names, data=data, dtypes=dtypes, asrecarray=True, usemask=False) + +def _repack_fields_dispatcher(a, align=None, recurse=None): + return (a,) + + +@array_function_dispatch(_repack_fields_dispatcher) def repack_fields(a, align=False, recurse=False): """ Re-pack the fields of a structured array or dtype in memory. @@ -811,6 +870,13 @@ def repack_fields(a, align=False, recurse=False): dt = np.dtype(fieldinfo, align=align) return np.dtype((a.type, dt)) + +def _stack_arrays_dispatcher(arrays, defaults=None, usemask=None, + asrecarray=None, autoconvert=None): + return arrays + + +@array_function_dispatch(_stack_arrays_dispatcher) def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False, autoconvert=False): """ @@ -897,6 +963,12 @@ def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False, usemask=usemask, asrecarray=asrecarray) +def _find_duplicates_dispatcher( + a, key=None, ignoremask=None, return_index=None): + return (a,) + + +@array_function_dispatch(_find_duplicates_dispatcher) def find_duplicates(a, key=None, ignoremask=True, return_index=False): """ Find the duplicates in a structured array along a given key @@ -951,8 +1023,15 @@ def find_duplicates(a, key=None, ignoremask=True, return_index=False): return duplicates +def _join_by_dispatcher( + key, r1, r2, jointype=None, r1postfix=None, r2postfix=None, + defaults=None, usemask=None, asrecarray=None): + return (r1, r2) + + +@array_function_dispatch(_join_by_dispatcher) def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', - defaults=None, usemask=True, asrecarray=False): + defaults=None, usemask=True, asrecarray=False): """ Join arrays `r1` and `r2` on key `key`. @@ -1130,6 +1209,13 @@ def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', return _fix_output(_fix_defaults(output, defaults), **kwargs) +def _rec_join_dispatcher( + key, r1, r2, jointype=None, r1postfix=None, r2postfix=None, + defaults=None): + return (r1, r2) + + +@array_function_dispatch(_rec_join_dispatcher) def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2', defaults=None): """ diff --git a/numpy/lib/scimath.py b/numpy/lib/scimath.py index f1838fee6d37..9ca006841136 100644 --- a/numpy/lib/scimath.py +++ b/numpy/lib/scimath.py @@ -20,6 +20,7 @@ import numpy.core.numeric as nx import numpy.core.numerictypes as nt from numpy.core.numeric import asarray, any +from numpy.core.overrides import array_function_dispatch from numpy.lib.type_check import isreal @@ -94,6 +95,7 @@ def _tocomplex(arr): else: return arr.astype(nt.cdouble) + def _fix_real_lt_zero(x): """Convert `x` to complex if it has real, negative components. @@ -121,6 +123,7 @@ def _fix_real_lt_zero(x): x = _tocomplex(x) return x + def _fix_int_lt_zero(x): """Convert `x` to double if it has real, negative components. @@ -147,6 +150,7 @@ def _fix_int_lt_zero(x): x = x * 1.0 return x + def _fix_real_abs_gt_1(x): """Convert `x` to complex if it has real components x_i with abs(x_i)>1. @@ -173,6 +177,12 @@ def _fix_real_abs_gt_1(x): x = _tocomplex(x) return x + +def _unary_dispatcher(x): + return (x,) + + +@array_function_dispatch(_unary_dispatcher) def sqrt(x): """ Compute the square root of x. @@ -215,6 +225,8 @@ def sqrt(x): x = _fix_real_lt_zero(x) return nx.sqrt(x) + +@array_function_dispatch(_unary_dispatcher) def log(x): """ Compute the natural logarithm of `x`. @@ -261,6 +273,8 @@ def log(x): x = _fix_real_lt_zero(x) return nx.log(x) + +@array_function_dispatch(_unary_dispatcher) def log10(x): """ Compute the logarithm base 10 of `x`. @@ -309,6 +323,12 @@ def log10(x): x = _fix_real_lt_zero(x) return nx.log10(x) + +def _logn_dispatcher(n, x): + return (n, x,) + + +@array_function_dispatch(_logn_dispatcher) def logn(n, x): """ Take log base n of x. @@ -318,8 +338,8 @@ def logn(n, x): Parameters ---------- - n : int - The base in which the log is taken. + n : array_like + The integer base(s) in which the log is taken. x : array_like The value(s) whose log base `n` is (are) required. @@ -343,6 +363,8 @@ def logn(n, x): n = _fix_real_lt_zero(n) return nx.log(x)/nx.log(n) + +@array_function_dispatch(_unary_dispatcher) def log2(x): """ Compute the logarithm base 2 of `x`. @@ -389,6 +411,12 @@ def log2(x): x = _fix_real_lt_zero(x) return nx.log2(x) + +def _power_dispatcher(x, p): + return (x, p) + + +@array_function_dispatch(_power_dispatcher) def power(x, p): """ Return x to the power p, (x**p). @@ -432,6 +460,8 @@ def power(x, p): p = _fix_int_lt_zero(p) return nx.power(x, p) + +@array_function_dispatch(_unary_dispatcher) def arccos(x): """ Compute the inverse cosine of x. @@ -475,6 +505,8 @@ def arccos(x): x = _fix_real_abs_gt_1(x) return nx.arccos(x) + +@array_function_dispatch(_unary_dispatcher) def arcsin(x): """ Compute the inverse sine of x. @@ -519,6 +551,8 @@ def arcsin(x): x = _fix_real_abs_gt_1(x) return nx.arcsin(x) + +@array_function_dispatch(_unary_dispatcher) def arctanh(x): """ Compute the inverse hyperbolic tangent of `x`. diff --git a/numpy/lib/shape_base.py b/numpy/lib/shape_base.py index 66f53473485d..e8d43958a131 100644 --- a/numpy/lib/shape_base.py +++ b/numpy/lib/shape_base.py @@ -8,6 +8,7 @@ ) from numpy.core.fromnumeric import product, reshape, transpose from numpy.core.multiarray import normalize_axis_index +from numpy.core.overrides import array_function_dispatch from numpy.core import vstack, atleast_3d from numpy.lib.index_tricks import ndindex from numpy.matrixlib.defmatrix import matrix # this raises all the right alarm bells @@ -44,6 +45,11 @@ def _make_along_axis_idx(arr_shape, indices, axis): return tuple(fancy_index) +def _take_along_axis_dispatcher(arr, indices, axis): + return (arr, indices) + + +@array_function_dispatch(_take_along_axis_dispatcher) def take_along_axis(arr, indices, axis): """ Take values from the input array by matching 1d index and data slices. @@ -160,6 +166,11 @@ def take_along_axis(arr, indices, axis): return arr[_make_along_axis_idx(arr_shape, indices, axis)] +def _put_along_axis_dispatcher(arr, indices, values, axis): + return (arr, indices, values) + + +@array_function_dispatch(_put_along_axis_dispatcher) def put_along_axis(arr, indices, values, axis): """ Put values into the destination array by matching 1d index and data slices. @@ -245,6 +256,11 @@ def put_along_axis(arr, indices, values, axis): arr[_make_along_axis_idx(arr_shape, indices, axis)] = values +def _apply_along_axis_dispatcher(func1d, axis, arr, *args, **kwargs): + return (arr,) + + +@array_function_dispatch(_apply_along_axis_dispatcher) def apply_along_axis(func1d, axis, arr, *args, **kwargs): """ Apply a function to 1-D slices along the given axis. @@ -392,6 +408,11 @@ def apply_along_axis(func1d, axis, arr, *args, **kwargs): return res.__array_wrap__(out_arr) +def _apply_over_axes_dispatcher(func, a, axes): + return (a,) + + +@array_function_dispatch(_apply_over_axes_dispatcher) def apply_over_axes(func, a, axes): """ Apply a function repeatedly over multiple axes. @@ -474,9 +495,15 @@ def apply_over_axes(func, a, axes): val = res else: raise ValueError("function is not returning " - "an array of the correct shape") + "an array of the correct shape") return val + +def _expand_dims_dispatcher(a, axis): + return (a,) + + +@array_function_dispatch(_expand_dims_dispatcher) def expand_dims(a, axis): """ Expand the shape of an array. @@ -554,8 +581,15 @@ def expand_dims(a, axis): # axis = normalize_axis_index(axis, a.ndim + 1) return a.reshape(shape[:axis] + (1,) + shape[axis:]) + row_stack = vstack + +def _column_stack_dispatcher(tup): + return tup + + +@array_function_dispatch(_column_stack_dispatcher) def column_stack(tup): """ Stack 1-D arrays as columns into a 2-D array. @@ -597,6 +631,12 @@ def column_stack(tup): arrays.append(arr) return _nx.concatenate(arrays, 1) + +def _dstack_dispatcher(tup): + return tup + + +@array_function_dispatch(_dstack_dispatcher) def dstack(tup): """ Stack arrays in sequence depth wise (along third axis). @@ -649,6 +689,7 @@ def dstack(tup): """ return _nx.concatenate([atleast_3d(_m) for _m in tup], 2) + def _replace_zero_by_x_arrays(sub_arys): for i in range(len(sub_arys)): if _nx.ndim(sub_arys[i]) == 0: @@ -657,6 +698,12 @@ def _replace_zero_by_x_arrays(sub_arys): sub_arys[i] = _nx.empty(0, dtype=sub_arys[i].dtype) return sub_arys + +def _array_split_dispatcher(ary, indices_or_sections, axis=None): + return (ary, indices_or_sections) + + +@array_function_dispatch(_array_split_dispatcher) def array_split(ary, indices_or_sections, axis=0): """ Split an array into multiple sub-arrays. @@ -712,7 +759,12 @@ def array_split(ary, indices_or_sections, axis=0): return sub_arys -def split(ary,indices_or_sections,axis=0): +def _split_dispatcher(ary, indices_or_sections, axis=None): + return (ary, indices_or_sections) + + +@array_function_dispatch(_split_dispatcher) +def split(ary, indices_or_sections, axis=0): """ Split an array into multiple sub-arrays. @@ -789,6 +841,12 @@ def split(ary,indices_or_sections,axis=0): res = array_split(ary, indices_or_sections, axis) return res + +def _hvdsplit_dispatcher(ary, indices_or_sections): + return (ary, indices_or_sections) + + +@array_function_dispatch(_hvdsplit_dispatcher) def hsplit(ary, indices_or_sections): """ Split an array into multiple sub-arrays horizontally (column-wise). @@ -851,6 +909,8 @@ def hsplit(ary, indices_or_sections): else: return split(ary, indices_or_sections, 0) + +@array_function_dispatch(_hvdsplit_dispatcher) def vsplit(ary, indices_or_sections): """ Split an array into multiple sub-arrays vertically (row-wise). @@ -902,6 +962,8 @@ def vsplit(ary, indices_or_sections): raise ValueError('vsplit only works on arrays of 2 or more dimensions') return split(ary, indices_or_sections, 0) + +@array_function_dispatch(_hvdsplit_dispatcher) def dsplit(ary, indices_or_sections): """ Split array into multiple sub-arrays along the 3rd axis (depth). @@ -971,6 +1033,12 @@ def get_array_wrap(*args): return wrappers[-1][-1] return None + +def _kron_dispatcher(a, b): + return (a, b) + + +@array_function_dispatch(_kron_dispatcher) def kron(a, b): """ Kronecker product of two arrays. @@ -1070,6 +1138,11 @@ def kron(a, b): return result +def _tile_dispatcher(A, reps): + return (A, reps) + + +@array_function_dispatch(_tile_dispatcher) def tile(A, reps): """ Construct an array by repeating A the number of times given by reps. diff --git a/numpy/lib/stride_tricks.py b/numpy/lib/stride_tricks.py index ca13738c1f37..3115c5e37f62 100644 --- a/numpy/lib/stride_tricks.py +++ b/numpy/lib/stride_tricks.py @@ -8,6 +8,7 @@ from __future__ import division, absolute_import, print_function import numpy as np +from numpy.core.overrides import array_function_dispatch __all__ = ['broadcast_to', 'broadcast_arrays'] @@ -135,6 +136,11 @@ def _broadcast_to(array, shape, subok, readonly): return result +def _broadcast_to_dispatcher(array, shape, subok=None): + return (array,) + + +@array_function_dispatch(_broadcast_to_dispatcher) def broadcast_to(array, shape, subok=False): """Broadcast an array to a new shape. @@ -195,6 +201,11 @@ def _broadcast_shape(*args): return b.shape +def _broadcast_arrays_dispatcher(*args, **kwargs): + return args + + +@array_function_dispatch(_broadcast_arrays_dispatcher) def broadcast_arrays(*args, **kwargs): """ Broadcast any number of arrays against each other. diff --git a/numpy/lib/twodim_base.py b/numpy/lib/twodim_base.py index 98efba1911e1..aff6cddde0b5 100644 --- a/numpy/lib/twodim_base.py +++ b/numpy/lib/twodim_base.py @@ -8,6 +8,7 @@ asarray, where, int8, int16, int32, int64, empty, promote_types, diagonal, nonzero ) +from numpy.core.overrides import array_function_dispatch from numpy.core import iinfo, transpose @@ -33,6 +34,11 @@ def _min_int(low, high): return int64 +def _flip_dispatcher(m): + return (m,) + + +@array_function_dispatch(_flip_dispatcher) def fliplr(m): """ Flip array in the left/right direction. @@ -83,6 +89,7 @@ def fliplr(m): return m[:, ::-1] +@array_function_dispatch(_flip_dispatcher) def flipud(m): """ Flip array in the up/down direction. @@ -194,6 +201,11 @@ def eye(N, M=None, k=0, dtype=float, order='C'): return m +def _diag_dispatcher(v, k=None): + return (v,) + + +@array_function_dispatch(_diag_dispatcher) def diag(v, k=0): """ Extract a diagonal or construct a diagonal array. @@ -265,6 +277,7 @@ def diag(v, k=0): raise ValueError("Input must be 1- or 2-d.") +@array_function_dispatch(_diag_dispatcher) def diagflat(v, k=0): """ Create a two-dimensional array with the flattened input as a diagonal. @@ -373,6 +386,11 @@ def tri(N, M=None, k=0, dtype=float): return m +def _trilu_dispatcher(m, k=None): + return (m,) + + +@array_function_dispatch(_trilu_dispatcher) def tril(m, k=0): """ Lower triangle of an array. @@ -411,6 +429,7 @@ def tril(m, k=0): return where(mask, m, zeros(1, m.dtype)) +@array_function_dispatch(_trilu_dispatcher) def triu(m, k=0): """ Upper triangle of an array. @@ -439,7 +458,12 @@ def triu(m, k=0): return where(mask, zeros(1, m.dtype), m) +def _vander_dispatcher(x, N=None, increasing=None): + return (x,) + + # Originally borrowed from John Hunter and matplotlib +@array_function_dispatch(_vander_dispatcher) def vander(x, N=None, increasing=False): """ Generate a Vandermonde matrix. @@ -530,6 +554,12 @@ def vander(x, N=None, increasing=False): return v +def _histogram2d_dispatcher(x, y, bins=None, range=None, normed=None, + weights=None, density=None): + return (x, y, bins, weights) + + +@array_function_dispatch(_histogram2d_dispatcher) def histogram2d(x, y, bins=10, range=None, normed=None, weights=None, density=None): """ @@ -812,6 +842,11 @@ def tril_indices(n, k=0, m=None): return nonzero(tri(n, m, k=k, dtype=bool)) +def _trilu_indices_form_dispatcher(arr, k=None): + return (arr,) + + +@array_function_dispatch(_trilu_indices_form_dispatcher) def tril_indices_from(arr, k=0): """ Return the indices for the lower-triangle of arr. @@ -922,6 +957,7 @@ def triu_indices(n, k=0, m=None): return nonzero(~tri(n, m, k=k-1, dtype=bool)) +@array_function_dispatch(_trilu_indices_form_dispatcher) def triu_indices_from(arr, k=0): """ Return the indices for the upper-triangle of arr. diff --git a/numpy/lib/type_check.py b/numpy/lib/type_check.py index 603da8567646..5f74d3ca211e 100644 --- a/numpy/lib/type_check.py +++ b/numpy/lib/type_check.py @@ -11,6 +11,7 @@ import numpy.core.numeric as _nx from numpy.core.numeric import asarray, asanyarray, array, isnan, zeros +from numpy.core.overrides import array_function_dispatch from .ufunclike import isneginf, isposinf _typecodes_by_elsize = 'GDFgdfQqLlIiHhBb?' @@ -104,6 +105,11 @@ def asfarray(a, dtype=_nx.float_): return asarray(a, dtype=dtype) +def _real_dispatcher(val): + return (val,) + + +@array_function_dispatch(_real_dispatcher) def real(val): """ Return the real part of the complex argument. @@ -145,6 +151,11 @@ def real(val): return asanyarray(val).real +def _imag_dispatcher(val): + return (val,) + + +@array_function_dispatch(_imag_dispatcher) def imag(val): """ Return the imaginary part of the complex argument. @@ -183,6 +194,11 @@ def imag(val): return asanyarray(val).imag +def _is_type_dispatcher(x): + return (x,) + + +@array_function_dispatch(_is_type_dispatcher) def iscomplex(x): """ Returns a bool array, where True if input element is complex. @@ -218,6 +234,8 @@ def iscomplex(x): res = zeros(ax.shape, bool) return res[()] # convert to scalar if needed + +@array_function_dispatch(_is_type_dispatcher) def isreal(x): """ Returns a bool array, where True if input element is real. @@ -248,6 +266,8 @@ def isreal(x): """ return imag(x) == 0 + +@array_function_dispatch(_is_type_dispatcher) def iscomplexobj(x): """ Check for a complex type or an array of complex numbers. @@ -288,6 +308,7 @@ def iscomplexobj(x): return issubclass(type_, _nx.complexfloating) +@array_function_dispatch(_is_type_dispatcher) def isrealobj(x): """ Return True if x is a not complex type or an array of complex numbers. @@ -329,6 +350,12 @@ def _getmaxmin(t): f = getlimits.finfo(t) return f.max, f.min + +def _nan_to_num_dispatcher(x, copy=None): + return (x,) + + +@array_function_dispatch(_nan_to_num_dispatcher) def nan_to_num(x, copy=True): """ Replace NaN with zero and infinity with large finite numbers. @@ -411,7 +438,12 @@ def nan_to_num(x, copy=True): #----------------------------------------------------------------------------- -def real_if_close(a,tol=100): +def _real_if_close_dispatcher(a, tol=None): + return (a,) + + +@array_function_dispatch(_real_if_close_dispatcher) +def real_if_close(a, tol=100): """ If complex input returns a real array if complex parts are close to zero. @@ -466,6 +498,11 @@ def real_if_close(a,tol=100): return a +def _asscalar_dispatcher(a): + return (a,) + + +@array_function_dispatch(_asscalar_dispatcher) def asscalar(a): """ Convert an array of size 1 to its scalar equivalent. @@ -586,6 +623,13 @@ def typename(char): _nx.csingle: 1, _nx.cdouble: 2, _nx.clongdouble: 3} + + +def _common_type_dispatcher(*arrays): + return arrays + + +@array_function_dispatch(_common_type_dispatcher) def common_type(*arrays): """ Return a scalar type which is common to the input arrays. diff --git a/numpy/lib/ufunclike.py b/numpy/lib/ufunclike.py index 6259c544530b..1c6f0417a6ea 100644 --- a/numpy/lib/ufunclike.py +++ b/numpy/lib/ufunclike.py @@ -8,6 +8,7 @@ __all__ = ['fix', 'isneginf', 'isposinf'] import numpy.core.numeric as nx +from numpy.core.overrides import array_function_dispatch import warnings import functools @@ -37,7 +38,30 @@ def func(x, out=None, **kwargs): return func +def _fix_out_named_y(f): + """ + Allow the out argument to be passed as the name `y` (deprecated) + + This decorator should only be used if _deprecate_out_named_y is used on + a corresponding dispatcher fucntion. + """ + @functools.wraps(f) + def func(x, out=None, **kwargs): + if 'y' in kwargs: + # we already did error checking in _deprecate_out_named_y + out = kwargs.pop('y') + return f(x, out=out, **kwargs) + + return func + + @_deprecate_out_named_y +def _dispatcher(x, out=None): + return (x, out) + + +@array_function_dispatch(_dispatcher, verify=False) +@_fix_out_named_y def fix(x, out=None): """ Round to nearest integer towards zero. @@ -83,7 +107,8 @@ def fix(x, out=None): return res -@_deprecate_out_named_y +@array_function_dispatch(_dispatcher, verify=False) +@_fix_out_named_y def isposinf(x, out=None): """ Test element-wise for positive infinity, return result as bool array. @@ -151,7 +176,8 @@ def isposinf(x, out=None): return nx.logical_and(is_inf, signbit, out) -@_deprecate_out_named_y +@array_function_dispatch(_dispatcher, verify=False) +@_fix_out_named_y def isneginf(x, out=None): """ Test element-wise for negative infinity, return result as bool array. diff --git a/numpy/testing/_private/utils.py b/numpy/testing/_private/utils.py index a3832fcdef52..b37dac69dd20 100644 --- a/numpy/testing/_private/utils.py +++ b/numpy/testing/_private/utils.py @@ -19,7 +19,7 @@ import pprint from numpy.core import( - float32, empty, arange, array_repr, ndarray, isnat, array) + bool_, float32, empty, arange, array_repr, ndarray, isnat, array) from numpy.lib.utils import deprecate if sys.version_info[0] >= 3: @@ -352,7 +352,7 @@ def assert_equal(actual, desired, err_msg='', verbose=True): # XXX: catch ValueError for subclasses of ndarray where iscomplex fail try: usecomplex = iscomplexobj(actual) or iscomplexobj(desired) - except ValueError: + except (ValueError, TypeError): usecomplex = False if usecomplex: @@ -705,15 +705,20 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): at the same locations. """ - # Both the != True comparison here and the cast to bool_ at the end are - # done to deal with `masked`, which cannot be compared usefully, and - # for which np.all yields masked. The use of the function np.all is - # for back compatibility with ndarray subclasses that changed the - # return values of the all method. We are not committed to supporting - # such subclasses, but some used to work. x_id = func(x) y_id = func(y) - if npall(x_id == y_id) != True: + # We include work-arounds here to handle three types of slightly + # pathological ndarray subclasses: + # (1) all() on `masked` array scalars can return masked arrays, so we + # use != True + # (2) __eq__ on some ndarray subclasses returns Python booleans + # instead of element-wise comparisons, so we cast to bool_() and + # use isinstance(..., bool) checks + # (3) subclasses with bare-bones __array_function__ implemenations may + # not implement np.all(), so favor using the .all() method + # We are not committed to supporting such subclasses, but it's nice to + # support them if possible. + if bool_(x_id == y_id).all() != True: msg = build_err_msg([x, y], err_msg + '\nx and y %s location mismatch:' % (hasval), verbose=verbose, header=header, @@ -721,9 +726,9 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'): raise AssertionError(msg) # If there is a scalar, then here we know the array has the same # flag as it everywhere, so we should return the scalar flag. - if x_id.ndim == 0: + if isinstance(x_id, bool) or x_id.ndim == 0: return bool_(x_id) - elif y_id.ndim == 0: + elif isinstance(x_id, bool) or y_id.ndim == 0: return bool_(y_id) else: return y_id diff --git a/numpy/testing/tests/test_utils.py b/numpy/testing/tests/test_utils.py index e0d3414f795d..8099db0b78af 100644 --- a/numpy/testing/tests/test_utils.py +++ b/numpy/testing/tests/test_utils.py @@ -158,6 +158,44 @@ def test_masked_nan_inf(self): self._test_equal(a, b) self._test_equal(b, a) + def test_subclass_that_overrides_eq(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __eq__(self, other): + return bool(np.equal(self, other).all()) + + def __ne__(self, other): + return not self == other + + a = np.array([1., 2.]).view(MyArray) + b = np.array([2., 3.]).view(MyArray) + assert_(type(a == a), bool) + assert_(a == a) + assert_(a != b) + self._test_equal(a, a) + self._test_not_equal(a, b) + self._test_not_equal(b, a) + + def test_subclass_that_does_not_implement_npall(self): + # While we cannot guarantee testing functions will always work for + # subclasses, the tests should ideally rely only on subclasses having + # comparison operators, not on them being able to store booleans + # (which, e.g., astropy Quantity cannot usefully do). See gh-8452. + class MyArray(np.ndarray): + def __array_function__(self, *args, **kwargs): + return NotImplemented + + a = np.array([1., 2.]).view(MyArray) + b = np.array([2., 3.]).view(MyArray) + with assert_raises(TypeError): + np.all(a) + self._test_equal(a, a) + self._test_not_equal(a, b) + self._test_not_equal(b, a) + class TestBuildErrorMessage(object):