8000 ENH: `__array_function__` support for `np.lib`, part 2/2 by shoyer · Pull Request #12119 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

ENH: __array_function__ support for np.lib, part 2/2 #12119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions numpy/core/tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ def _get_overloaded_args(relevant_args):
return args


def _return_self(self, *args, **kwargs):
return self
def _return_not_implemented(self, *args, **kwargs):
return NotImplemented


class TestGetOverloadedTypesAndArgs(object):
Expand Down Expand Up @@ -45,7 +45,7 @@ def test_ndarray(self):
def test_ndarray_subclasses(self):

class OverrideSub(np.ndarray):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class NoOverrideSub(np.ndarray):
pass
Expand All @@ -70,7 +70,7 @@ class NoOverrideSub(np.ndarray):
def test_ndarray_and_duck_array(self):

class Other(object):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

array = np.array(1)
other = Other()
Expand All @@ -86,10 +86,10 @@ class Other(object):
def test_ndarray_subclass_and_duck_array(self):

class OverrideSub(np.ndarray):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class Other(object):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

array = np.array(1)
subarray = np.array(1).view(OverrideSub)
Expand All @@ -103,16 +103,16 @@ class Other(object):
def test_many_duck_arrays(self):

class A(object):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class B(A):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class C(A):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class D(object):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

a = A()
b = B()
Expand All @@ -135,7 +135,7 @@ class TestNDArrayArrayFunction(object):
def test_method(self):

class SubOverride(np.ndarray):
__array_function__ = _return_self
__array_function__ = _return_not_implemented

class NoOverrideSub(np.ndarray):
pass
Expand Down Expand Up @@ -189,7 +189,8 @@ def __array_function__(self, func, types, args, kwargs):
assert_(obj is original)
assert_(func is dispatched_one_arg)
assert_equal(set(types), {MyArray})
assert_equal(args, (original,))
# assert_equal uses the overloaded np.iscomplexobj() internally
assert_(args == (original,))
assert_equal(kwargs, {})

def test_not_implemented(self):
Expand Down
29 changes: 29 additions & 0 deletions numpy/lib/npyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import format
from ._datasource import DataSource
from numpy.core.multiarray import packbits, unpackbits
from numpy.core.overrides import array_function_dispatch
from numpy.core._internal import recursive
from ._iotools import (
LineSplitter, NameValidator, StringConverter, ConverterError,
Expand Down Expand Up @@ -447,6 +448,11 @@ def load(file, mmap_mode=None, allow_pickle=True, fix_imports=True,
fid.close()


def _save_dispatcher(file, arr, allow_pickle=None, fix_imports=None):
return (arr,)


@array_function_dispatch(_save_dispatcher)
def save(file, arr, allow_pickle=True, fix_imports=True):
"""
Save an array to a binary file in NumPy ``.npy`` format.
Expand Down Expand Up @@ -525,6 +531,14 @@ def save(file, arr, allow_pickle=True, fix_imports=True):
fid.close()


def _savez_dispatcher(file, *args, **kwds):
for a in args:
yield a
for v in kwds.values():
yield v


@array_function_dispatch(_savez_dispatcher)
def savez(file, *args, **kwds):
"""
Save several arrays into a single file in uncompressed ``.npz`` format.
Expand Down Expand Up @@ -604,6 +618,14 @@ def savez(file, *args, **kwds):
_savez(file, args, kwds, False)


def _savez_compressed_dispatcher(file, *args, **kwds):
for a in args:
yield a
for v in kwds.values():
yield v


@array_function_dispatch(_savez_compressed_dispatcher)
def savez_compressed(file, *args, **kwds):
"""
Save several arrays into a single file in compressed ``.npz`` format.
Expand Down Expand Up @@ -1154,6 +1176,13 @@ def tobytes_first(x, conv):
return X


def _savetxt_dispatcher(fname, X, fmt=None, delimiter=None, newline=None,
header=None, footer=None, comments=None,
encoding=None):
return (X,)


@array_function_dispatch(_savetxt_dispatcher)
def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
footer='', comments='# ', encoding=None):
"""
Expand Down
52 changes: 52 additions & 0 deletions numpy/lib/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@

from numpy.core import (isscalar, abs, finfo, atleast_1d, hstack, dot, array,
ones)
from numpy.core.overrides import array_function_dispatch
from numpy.lib.twodim_base import diag, vander
from numpy.lib.function_base import trim_zeros
from numpy.lib.type_check import iscomplex, real, imag, mintypecode
from numpy.linalg import eigvals, lstsq, inv


class RankWarning(UserWarning):
"""
Issued by `polyfit` when the Vandermonde matrix is rank deficient.
Expand All @@ -29,6 +31,12 @@ class RankWarning(UserWarning):
"""
pass


def _poly_dispatcher(seq_of_zeros):
return seq_of_zeros


@array_function_dispatch(_poly_dispatcher)
def poly(seq_of_zeros):
"""
Find the coefficients of a polynomial with the given sequence of roots.
Expand Down Expand Up @@ -145,6 +153,12 @@ def poly(seq_of_zeros):

return a


def _roots_dispatcher(p):
return p


@array_function_dispatch(_roots_dispatcher)
def roots(p):
"""
Return the roots of a polynomial with coefficients given in p.
Expand Down Expand Up @@ -229,6 +243,12 @@ def roots(p):
roots = hstack((roots, NX.zeros(trailing_zeros, roots.dtype)))
return roots


def _polyint_dispatcher(p, m=None, k=None):
return (p,)


@array_function_dispatch(_polyint_dispatcher)
def polyint(p, m=1, k=None):
"""
Return an antiderivative (indefinite integral) of a polynomial.
Expand Down Expand Up @@ -322,6 +342,12 @@ def polyint(p, m=1, k=None):
return poly1d(val)
return val


def _polyder_dispatcher(p, m=None):
return (p,)


@array_function_dispatch(_polyder_dispatcher)
def polyder(p, m=1):
"""
Return the derivative of the specified order of a polynomial.
Expand Down Expand Up @@ -390,6 +416,12 @@ def polyder(p, m=1):
val = poly1d(val)
return val


def _polyfit_dispatcher(x, y, deg, rcond=None, full=None, w=None, cov=None):
return (x, y, w)


@array_function_dispatch(_polyfit_dispatcher)
def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
"""
Least squares polynomial fit.
Expand Down Expand Up @@ -610,6 +642,11 @@ def polyfit(x, y, deg, rcond=None, full=False, w=None, cov=False):
return c


def _polyval_dispatcher(p, x):
return (p, x)


@array_function_dispatch(_polyval_dispatcher)
def polyval(p, x):
"""
Evaluate a polynomial at specific values.
Expand Down Expand Up @@ -679,6 +716,12 @@ def polyval(p, x):
y = y * x + p[i]
return y


def _binary_op_dispatcher(a1, a2):
return (a1, a2)


@array_function_dispatch(_binary_op_dispatcher)
def polyadd(a1, a2):
"""
Find the sum of two polynomials.
Expand Down Expand Up @@ -739,6 +782,8 @@ def polyadd(a1, a2):
val = poly1d(val)
return val


@array_function_dispatch(_binary_op_dispatcher)
def polysub(a1, a2):
"""
Difference (subtraction) of two polynomials.
Expand Down Expand Up @@ -786,6 +831,7 @@ def polysub(a1, a2):
return val


@array_function_dispatch(_binary_op_dispatcher)
def polymul(a1, a2):
"""
Find the product of two polynomials.
Expand Down Expand Up @@ -842,6 +888,12 @@ def polymul(a1, a2):
val = poly1d(val)
return val


def _polydiv_dispatcher(u, v):
return (u, v)


@array_function_dispatch(_polydiv_dispatcher)
def polydiv(u, v):
"""
Returns the quotient and remainder of polynomial division.
Expand Down
Loading
0