8000 API: Add Array API aliases (math, bitwise, linalg, misc) [Array API] by mtsokol · Pull Request #25086 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

API: Add Array API aliases (math, bitwise, linalg, misc) [Array API] #25086

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions doc/release/upcoming_changes/25086.new_feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
Array API compatible functions' aliases
---------------------------------------

13 aliases for existing functions were added to improve compatibility with the Array API standard:

* Trigonometry: ``acos``, ``acosh``, ``asin``, ``asinh``, ``atan``, ``atanh``, ``atan2``.

* Bitwise: ``bitwise_left_shift``, ``bitwise_invert``, ``bitwise_right_shift``.

* Misc: ``concat``, ``permute_dims``, ``pow``.

* linalg: ``tensordot``, ``matmul``.
56 changes: 0 additions & 56 deletions doc/source/reference/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,66 +71,10 @@ The following functions are named differently in the array API
* - Array API name
- NumPy namespace name
- Notes
* - ``acos``
- ``arccos``
-
* - ``acosh``
- ``arccosh``
-
* - ``asin``
- ``arcsin``
-
* - ``asinh``
- ``arcsinh``
-
* - ``atan``
- ``arctan``
-
* - ``atan2``
- ``arctan2``
-
* - ``atanh``
- ``arctanh``
-
* - ``bitwise_left_shift``
- ``left_shift``
-
* - ``bitwise_invert``
- ``invert``
-
* - ``bitwise_right_shift``
- ``right_shift``
-
* - ``bool``
- ``bool_``
- This is **breaking** because ``np.bool`` is currently a deprecated
alias for the built-in ``bool``.
* - ``concat``
- ``concatenate``
-
* - ``matrix_norm`` and ``vector_norm``
- ``norm``
- ``matrix_norm`` and ``vector_norm`` each do a limited subset of what
``np.norm`` does.
* - ``permute_dims``
- ``transpose``
- Unlike ``np.transpose``, the ``axis`` keyword-argument to
``permute_dims`` is required.
* - ``pow``
- ``power``
-


``linalg`` Namespace Differences
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

These functions are in the ``linalg`` sub-namespace in the array API, but are
only in the top-level namespace in NumPy:

- ``matmul`` (*)
- ``tensordot`` (*)

(*): These functions are also in the top-level namespace in the array API.

Keyword Argument Renames
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions doc/source/reference/routines.array-manipulation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Transpose-like operations
swapaxes
ndarray.T
transpose
permute_dims

Changing number of dimensions
=============================
Expand Down Expand Up @@ -66,6 +67,7 @@ Joining arrays
:toctree: generated/

concatenate
concat
stack
block
vstack
Expand Down
3 changes: 3 additions & 0 deletions doc/source/reference/routines.bitwise.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@ Elementwise bit operations
bitwise_or
bitwise_xor
invert
bitwise_invert
left_shift
bitwise_left_shift
right_shift
bitwise_right_shift

Bit packing
-----------
Expand Down
2 changes: 2 additions & 0 deletions doc/source/reference/routines.linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ Matrix and vector products
inner
outer
matmul
linalg.matmul (Array API compatible location)
tensordot
linalg.tensordot (Array API compatible location)
einsum
einsum_path
linalg.matrix_power
Expand Down
8 changes: 8 additions & 0 deletions doc/source/reference/routines.math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@ Trigonometric functions
cos
tan
arcsin
asin
arccos
acos
arctan
atan
hypot
arctan2
atan2
degrees
radians
unwrap
Expand All @@ -31,8 +35,11 @@ Hyperbolic functions
cosh
tanh
arcsinh
asinh
arccosh
acosh
arctanh
atanh

Rounding
--------
Expand Down Expand Up @@ -120,6 +127,7 @@ Arithmetic operations
multiply
divide
power
pow
subtract
true_divide
floor_divide
Expand Down
61 changes: 31 additions & 30 deletions numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,24 +121,25 @@
from . import _core
from ._core import (
False_, ScalarType, True_, _get_promotion_state, _no_nep50_warning,
_set_promotion_state, abs, absolute, add, all, allclose, alltrue,
amax, amin, any, arange, arccos, arccosh, arcsin, arcsinh, arctan,
arctan2, arctanh, argmax, argmin, argpartition, argsort, argwhere,
around, array, array2string, array_equal, array_equiv, array_repr,
array_str, asanyarray, asarray, ascontiguousarray, asfortranarray,
astype, atleast_1d, atleast_2d, atleast_3d, base_repr, binary_repr,
bitwise_and, bitwise_count, bitwise_not, bitwise_or, bitwise_xor,
block, bool, bool_, broadcast, busday_count, busday_offset,
busdaycalendar, byte, bytes_, can_cast, cbrt, cdouble, ceil,
character, choose, clip, clongdouble, complex128, complex64,
complexfloating, compress, concatenate, conj, conjugate, convolve,
copysign, copyto, correlate, cos, cosh, count_nonzero, cross, csingle,
cumprod, cumproduct, cumsum, datetime64, datetime_as_string,
datetime_data, deg2rad, degrees, diagonal, divide, divmod, dot,
double, dtype, e, einsum, einsum_path, empty, empty_like, equal,
errstate, euler_gamma, exp, exp2, expm1, fabs, finfo, flatiter,
flatnonzero, flexible, float16, float32, float64, float_power,
floating, floor, floor_divide, fmax, fmin, fmod,
_set_promotion_state, abs, absolute, acos, acosh, add, all, allclose,
alltrue, amax, amin, any, arange, arccos, arccosh, arcsin, arcsinh,
arctan, arctan2, arctanh, argmax, argmin, argpartition, argsort,
argwhere, around, array, array2string, array_equal, array_equiv,
array_repr, array_str, asanyarray, asarray, ascontiguousarray,
asfortranarray, asin, asinh, atan, atanh, atan2, astype, atleast_1d,
atleast_2d, atleast_3d, base_repr, binary_repr, bitwise_and,
bitwise_count, bitwise_invert, bitwise_left_shift, bitwise_not,
bitwise_or, bitwise_right_shift, bitwise_xor, block, bool, bool_,
broadcast, busday_count, busday_offset, busdaycalendar, byte, bytes_,
can_cast, cbrt, cdouble, ceil, character, choose, clip, clongdouble,
complex128, complex64, complexfloating, compress, concat, concatenate,
conj, conjugate, convolve, copysign, copyto, correlate, cos, cosh,
count_nonzero, cross, csingle, cumprod, cu 10000 mproduct, cumsum,
datetime64, datetime_as_string, datetime_data, deg2rad, degrees,
diagonal, divide, divmod, dot, double, dtype, e, einsum, einsum_path,
empty, empty_like, equal, errstate, euler_gamma, exp, exp2, expm1,
fabs, finfo, flatiter, flatnonzero, flexible, float16, float32,
float64, float_power, floating, floor, floor_divide, fmax, fmin, fmod,
format_float_positional, format_float_scientific, frexp, from_dlpack,
frombuffer, fromfile, fromfunction, fromiter, frompyfunc, fromstring,
full, full_like, gcd, generic, geomspace, get_printoptions,
Expand All @@ -153,18 +154,18 @@
may_share_memory, mean, memmap, min, min_scalar_type, minimum, mod,
modf, moveaxis, multiply, nan, ndarray, ndim, nditer, negative,
nested_iters, newaxis, nextafter, nonzero, not_equal, number, object_,
ones, ones_like, outer, partition, pi, positive, power, printoptions,
prod, product, promote_types, ptp, put, putmask, rad2deg, radians,
ravel, recarray, reciprocal, record, remainder, repeat, require,
reshape, resize, result_type, right_shift, rint, roll, rollaxis,
round, sctypeDict, searchsorted, set_printoptions, setbufsize, seterr,
seterrcall, shape, shares_memory, short, sign, signbit, signedinteger,
sin, single, sinh, size, sometrue, sort, spacing, sqrt, square,
squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh,
tensordot, timedelta64, trace, transpose, true_divide, trunc,
typecodes, ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc,
uintp, ulong, ulonglong, unsignedinteger, ushort, var, vdot, void,
vstack, where, zeros, zeros_like
ones, ones_like, outer, partition, permute_dims, pi, positive, pow,
power, printoptions, prod, product, promote_types, ptp, put, putmask,
rad2deg, radians, ravel, recarray, reciprocal, record, remainder,
repeat, require, reshape, resize, result_type, right_shift, rint,
roll, rollaxis, round, sctypeDict, searchsorted, set_printoptions,
setbufsize, seterr, seterrcall, shape, shares_memory, short, sign,
signbit, signedinteger, sin, single, sinh, size, sometrue, sort,
spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum,
swapaxes, take, tan, tanh, tensordot, timedelta64, trace, transpose,
true_divide, trunc, typecodes, ubyte, ufunc, uint, uint16, uint32,
uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger,
ushort, var, vdot, void, vstack, where, zeros, zeros_like
)

# NOTE: It's still under discussion whether these aliases
Expand Down
13 changes: 13 additions & 0 deletions numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3248,6 +3248,19 @@ true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]

abs = absolute
acos = arccos
acosh = arccosh
asin = arcsin
asinh = arcsinh
atan = arctan
atanh = arctanh
atan2 = arctan2
concat = concatenate
bitwise_left_shift = left_shift
bitwise_invert = invert
bitwise_right_shift = right_shift
permute_dims = transpose
pow = power

class _CopyMode(enum.Enum):
ALWAYS: L[True]
Expand Down
20 changes: 19 additions & 1 deletion numpy/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,25 @@
from . import _dtype
from . import _methods

__all__ = ['memmap', 'sctypeDict', 'record', 'recarray', 'abs']
acos = numeric.arccos
acosh = numeric.arccosh
asin = numeric.arcsin
asinh = numeric.arcsinh
atan = numeric.arctan
atanh = numeric.arctanh
atan2 = numeric.arctan2
concat = numeric.concatenate
bitwise_left_shift = numeric.left_shift
bitwise_invert = numeric.invert
bitwise_right_shift = numeric.right_shift
permute_dims = numeric.transpose
pow = numeric.power

__all__ = [
"abs", "acos", "acosh", "asin", "asinh", "atan", "atanh", "atan2",
"bitwise_invert", "bitwise_left_shift", "bitwise_right_shift", "concat",
"pow", "permute_dims", "memmap", "sctypeDict", "record", "recarray"
]
__all__ += numeric.__all__
__all__ += function_base.__all__
__all__ += getlimits.__all__
Expand Down
2 changes: 2 additions & 0 deletions numpy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
cross
multi_dot
matrix_power
tensordot
matmul

Decompositions
--------------
Expand Down
5 changes: 5 additions & 0 deletions numpy/linalg/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ from numpy.linalg._linalg import (
cond as cond,
matrix_rank as matrix_rank,
multi_dot as multi_dot,
matmul as matmul,
trace as trace,
diagonal as diagonal,
cross as cross,
)

from numpy._core.numeric import (
tensordot as tensordot,
)

from numpy._pytesttester import PytestTester

__all__: list[str]
Expand Down
63 changes: 61 additions & 2 deletions numpy/linalg/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
'svd', 'svdvals', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond',
'matrix_rank', 'LinAlgError', 'multi_dot', 'trace', 'diagonal',
'cross', 'outer']
'cross', 'outer', 'tensordot', 'matmul']

import functools
import operator
Expand All @@ -28,7 +28,8 @@
amax, prod, abs, atleast_2d, intp, asanyarray, object_, matmul,
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
cross as _core_cross, outer as _core_outer
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
matmul as _core_matmul,
)
from numpy.lib._twodim_base_impl import triu, eye
from numpy.lib.array_utils import normalize_axis_index
Expand Down Expand Up @@ -3129,3 +3130,61 @@ def cross(x1, x2, /, *, axis=-1):
)

return _core_cross(x1, x2, axis=axis)


# matmul

def _matmul_dispatcher(x1, x2, /):
return (x1, x2)


@array_function_dispatch(_matmul_dispatcher)
def matmul(x1, x2, /):
"""
Computes the matrix product.

This function is Array API compatible, contrary to
:func:`numpy.matmul`.

Parameters
----------
x1 : array_like
The first input array.
x2 : array_like
The second input array.

Returns
-------
out : ndarray
The matrix product of the inputs.
This is a scalar only when both ``x1``, ``x2`` are 1-d vectors.

Raises
------
ValueError
If the last dimension of ``x1`` is not the same size as
the second-to-last dimension of ``x2``.

If a scalar value is passed in.

See Also
--------
numpy.matmul

"""
return _core_matmul(x1, x2)


# tensordot

def _tensordot_dispatcher(
x1, x2, /, *, offset=None, dtype=None):
return (x1, x2)


@array_function_dispatch(_tensordot_dispatcher)
def tensordot(x1, x2, /, *, axes=2):
return _core_tensordot(x1, x2, axes=axes)


tensordot.__doc__ = _core_tensordot.__doc__
Loading
0