8000 API: Add Array API aliases (math, bitwise, linalg, misc) [Array API] by mtsokol · Pull Request #25086 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
API: Add tensordot to linalg
  • Loading branch information
mtsokol committed Dec 8, 2023
commit 2c5603c215e4837746c5e5a336981f2778c7f9f4
2 changes: 2 additions & 0 deletions doc/release/upcoming_changes/25086.new_feature.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ Array API compatible functions' aliases
* Bitwise: ``bitwise_left_shift``, ``bitwise_invert``, ``bitwise_right_shift``.

* Misc: ``concat``, ``permute_dims``, ``pow``.

* linalg: ``tensordot``
1 change: 0 additions & 1 deletion doc/source/reference/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ These functions are in the ``linalg`` sub-namespace in the array API, but are
only in the top-level namespace in NumPy:

- ``matmul`` (*)
- ``tensordot`` (*)

(*): These functions are also in the top-level namespace in the array API.

Expand Down
1 change: 1 addition & 0 deletions doc/source/reference/routines.linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Matrix and vector products
outer
matmul
tensordot
linalg.tensordot (Array API compatible location)
einsum
einsum_path
linalg.matrix_power
Expand Down
8 changes: 8 additions & 0 deletions numpy/_core/tests/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -4010,6 +4010,14 @@ def test_raise(self):

class TestTensordot:

@pytest.mark.parametrize(
"tensordot_func",
[np.tensordot, np.linalg.tensordot]
)
def test_tensordot(self, tensordot_func):
arr = np.arange(6).reshape((2, 3))
assert tensordot_func(arr, arr) == 55

def test_zero_dimension(self):
# Test resolution to issue #5663
a = np.ndarray((3,0))
Expand Down
1 change: 1 addition & 0 deletions numpy/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
cross
multi_dot
matrix_power
tensordot

Decompositions
--------------
Expand Down
4 changes: 4 additions & 0 deletions numpy/linalg/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ from numpy.linalg._linalg import (
cross as cross,
)

from numpy._core.numeric import (
tensordot as tensordot
)

from numpy._pytesttester import PytestTester

__all__: list[str]
Expand Down
19 changes: 17 additions & 2 deletions numpy/linalg/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
'svd', 'svdvals', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond',
'matrix_rank', 'LinAlgError', 'multi_dot', 'trace', 'diagonal',
'cross', 'outer']
'cross', 'outer', 'tensordot']

import functools
import operator
Expand All @@ -28,7 +28,7 @@
amax, prod, abs, atleast_2d, intp, asanyarray, object_, matmul,
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
cross as _core_cross, outer as _core_outer
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
)
from numpy.lib._twodim_base_impl import triu, eye
from numpy.lib.array_utils import normalize_axis_index
Expand Down Expand Up @@ -3129,3 +3129,18 @@ def cross(x1, x2, /, *, axis=-1):
)

return _core_cross(x1, x2, axis=axis)


# tensordot

def _tensordot_dispatcher(
x1, x2, /, *, offset=None, dtype=None):
return (x1, x2)


@array_function_dispatch(_tensordot_dispatcher)
def tensordot(x1, x2, /, *, axes=2):
return _core_tensordot(x1, x2, axes=axes)


tensordot.__doc__ = _core_tensordot.__doc__
0