8000 API: Add matmul to linalg · numpy/numpy@0c3cad6 · GitHub
[go: up one dir, main page]

Skip to content

Commit 0c3cad6

Browse files
committed
API: Add matmul to linalg
1 parent 2c5603c commit 0c3cad6

File tree

11 files changed

+92
-57
lines changed

11 files changed

+92
-57
lines changed

doc/release/upcoming_changes/25086.new_feature.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ Array API compatible functions' aliases
99

1010
* Misc: ``concat``, ``permute_dims``, ``pow``.
1111

12-
* linalg: ``tensordot``
12+
* linalg: ``tensordot``, ``matmul``.

doc/source/reference/array_api.rst

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,6 @@ The following functions are named differently in the array API
7676
- ``matrix_norm`` and ``vector_norm`` each do a limited subset of what
7777
``np.norm`` does.
7878

79-
80-
``linalg`` Namespace Differences
81-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
82-
83-
These functions are in the ``linalg`` sub-namespace in the array API, but are
84-
only in the top-level namespace in NumPy:
85-
86-
- ``matmul`` (*)
87-
88-
(*): These functions are also in the top-level namespace in the array API.
89-
9079
Keyword Argument Renames
9180
~~~~~~~~~~~~~~~~~~~~~~~~
9281

doc/source/reference/routines.linalg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Matrix and vector products
5959
inner
6060
outer
6161
matmul
62+
linalg.matmul (Array API compatible location)
6263
tensordot
6364
linalg.tensordot (Array API compatible location)
6465
einsum

numpy/_core/tests/test_numeric.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4010,14 +4010,6 @@ def test_raise(self):
40104010

40114011
class TestTensordot:
40124012

4013-
@pytest.mark.parametrize(
4014-
"tensordot_func",
4015-
[np.tensordot, np.linalg.tensordot]
4016-
)
4017-
def test_tensordot(self, tensordot_func):
4018-
arr = np.arange(6).reshape((2, 3))
4019-
assert tensordot_func(arr, arr) == 55
4020-
40214013
def test_zero_dimension(self):
40224014
# Test resolution to issue #5663
40234015
a = np.ndarray((3,0))

numpy/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
multi_dot
3131
matrix_power
3232
tensordot
33+
matmul
3334
3435
Decompositions
3536
--------------

numpy/linalg/__init__.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ from numpy.linalg._linalg import (
2121
cond as cond,
2222
matrix_rank as matrix_rank,
2323
multi_dot as multi_dot,
24+
matmul as matmul,
2425
trace as trace,
2526
diagonal as diagonal,
2627
cross as cross,
2728
)
2829

2930
from numpy._core.numeric import (
30-
tensordot as tensordot
31+
tensordot as tensordot,
3132
)
3233

3334
from numpy._pytesttester import PytestTester

numpy/linalg/_linalg.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
1414
'svd', 'svdvals', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond',
1515
'matrix_rank', 'LinAlgError', 'multi_dot', 'trace', 'diagonal',
16-
'cross', 'outer', 'tensordot']
16+
'cross', 'outer', 'tensordot', 'matmul']
1717

1818
import functools
1919
import operator
@@ -29,6 +29,7 @@
2929
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
3030
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
3131
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
32+
matmul as _core_matmul,
3233
)
3334
from numpy.lib._twodim_base_impl import triu, eye
3435
from numpy.lib.array_utils import normalize_axis_index
@@ -3131,6 +3132,49 @@ def cross(x1, x2, /, *, axis=-1):
31313132
return _core_cross(x1, x2, axis=axis)
31323133

31333134

3135+
# matmul
3136+
3137+
def _matmul_dispatcher(x1, x2, /):
3138+
return (x1, x2)
3139+
3140+
3141+
@array_function_dispatch(_matmul_dispatcher)
3142+
def matmul(x1, x2, /):
3143+
"""
3144+
Computes the matrix product.
3145+
3146+
This function is Array API compatible, contrary to
3147+
:func:`numpy.matmul`.
3148+
3149+
Parameters
3150+
----------
3151+
x1 : array_like
3152+
The first input array.
3153+
x2 : array_like
3154+
The second input array.
3155+
3156+
Returns
3157+
-------
3158+
out : ndarray
3159+
The matrix product of the inputs.
3160+
This is a scalar only when both ``x1``, ``x2`` are 1-d vectors.
3161+
3162+
Raises
3163+
------
3164+
ValueError
3165+
If the last dimension of ``x1`` is not the same size as
3166+
the second-to-last dimension of ``x2``.
3167+
3168+
If a scalar value is passed in.
3169+
3170+
See Also
3171+
--------
3172+
numpy.matmul
3173+
3174+
"""
3175+
return _core_matmul(x1, x2)
3176+
3177+
31343178
# tensordot
31353179

31363180
def _tensordot_dispatcher(

numpy/linalg/_linalg.pyi

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,24 @@ def cross(
373373
b: _ArrayLikeComplex_co,
374374
axis: int = ...,
375375
) -> NDArray[complexfloating[Any, Any]]: ...
376+
377+
@overload
378+
def matmul(
379+
x1: _ArrayLikeInt_co,
380+
x2: _ArrayLikeInt_co,
381+
) -> NDArray[signedinteger[Any]]: ...
382+
@overload
383+
def matmul(
384+
x1: _ArrayLikeUInt_co,
385+
x2: _ArrayLikeUInt_co,
386+
) -> NDArray[unsignedinteger[Any]]: ...
387+
@overload
388+
def matmul(
389+
x1: _ArrayLikeFloat_co,
390+
x2: _ArrayLikeFloat_co,
391+
) -> NDArray[floating[Any]]: ...
392+
@overload
393+
def matmul(
394+
x1: _ArrayLikeComplex_co,
395+
x2: _ArrayLikeComplex_co,
396+
) -> NDArray[complexfloating[Any, Any]]: ...

numpy/linalg/tests/test_linalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,6 @@ def test_trace():
22542254

22552255

22562256
def test_cross():
2257-
22582257
x = np.arange(9).reshape((3, 3))
22592258
actual = np.linalg.cross(x, x + 1)
22602259
expected = np.array([
@@ -2271,3 +2270,20 @@ def test_cross():
22712270
):
22722271
x_2dim = x[:, 1:]
22732272
np.linalg.cross(x_2dim, x_2dim)
2273+
2274+
2275+
def test_tensordot():
2276+
# np.linalg.tensordot is just an alias for np.tensordot
2277+
x = np.arange(6).reshape((2, 3))
2278+
2279+
assert np.linalg.tensordot(x, x) == 55
2280+
2281+
2282+
def test_matmul():
2283+
# np.linalg.matmul and np.matmul only differs in the number
2284+
# of arguments in the signature
2285+
x = np.arange(6).reshape((2, 3))
2286+
actual = np.linalg.matmul(x, x.T)
2287+
expected = np.array([[5, 14], [14, 50]])
2288+
2289+
assert_equal(actual, expected)

numpy/typing/tests/data/reveal/linalg.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,7 @@ assert_type(np.linalg.multi_dot([AR_m, AR_m]), Any)
118118
assert_type(np.linalg.cross(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
119119
assert_type(np.linalg.cross(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
120120
assert_type(np.linalg.cross(AR_c16, AR_c16), npt.NDArray[np.complexfloating[Any, Any]])
121+
122+
assert_type(np.linalg.matmul(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
123+
assert_type(np.linalg.matmul(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
124+
assert_type(np.linalg.matmul(AR_c16, AR_c16), npt.NDArray[np.complexfloating[Any, Any]])

tools/ci/array-api-skips.txt

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -41,26 +41,11 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32]
4141
array_api_tests/test_data_type_functions.py::test_isdtype
4242

4343
# missing names
44-
array_api_tests/test_has_names.py::test_has_names[linalg-matmul]
4544
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm]
4645
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
47-
array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
4846
array_api_tests/test_has_names.py::test_has_names[linalg-vecdot]
4947
array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
50-
array_api_tests/test_has_names.py::test_has_names[manipulation-concat]
51-
array_api_tests/test_has_names.py::test_has_names[manipulation-permute_dims]
5248
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
53-
array_api_tests/test_has_names.py::test_has_names[elementwise-acos]
54-
array_api_tests/test_has_names.py::test_has_names[elementwise-acosh]
55-
array_api_tests/test_has_names.py::test_has_names[elementwise-asin]
56-
array_api_tests/test_has_names.py::test_has_names[elementwise-asinh]
57-
array_api_tests/test_has_names.py::test_has_names[elementwise-atan]
58-
array_api_tests/test_has_names.py::test_has_names[elementwise-atan2]
59-
array_api_tests/test_has_names.py::test_has_names[elementwise-atanh]
60-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_left_shift]
61-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_invert]
62-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_right_shift]
63-
array_api_tests/test_has_names.py::test_has_names[elementwise-pow]
6449
array_api_tests/test_has_names.py::test_has_names[linear_algebra-matrix_transpose]
6550
array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
6651
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
@@ -72,10 +57,6 @@ array_api_tests/test_linalg.py::test_matrix_transpose
7257
array_api_tests/test_linalg.py::test_pinv
7358
array_api_tests/test_linalg.py::test_vecdot
7459

75-
# missing names
76-
array_api_tests/test_manipulation_functions.py::test_concat
77-
array_api_tests/test_manipulation_functions.py::test_permute_dims
78-
7960
# a few misalignments
8061
array_api_tests/test_operators_and_elementwise_functions.py
8162
array_api_tests/test_signatures.py::test_func_signature[std]
@@ -89,32 +70,17 @@ array_api_tests/test_signatures.py::test_func_signature[linspace]
8970
array_api_tests/test_signatures.py::test_func_signature[ones]
9071
array_api_tests/test_signatures.py::test_func_signature[ones_like]
9172
array_api_tests/test_signatures.py::test_func_signature[zeros_like]
92-
array_api_tests/test_signatures.py::test_func_signature[concat]
93-
array_api_tests/test_signatures.py::test_func_signature[permute_dims]
9473
array_api_tests/test_signatures.py::test_func_signature[reshape]
9574
array_api_tests/test_signatures.py::test_func_signature[argsort]
9675
array_api_tests/test_signatures.py::test_func_signature[sort]
9776
array_api_tests/test_signatures.py::test_func_signature[isdtype]
98-
array_api_tests/test_signatures.py::test_func_signature[acos]
99-
array_api_tests/test_signatures.py::test_func_signature[acosh]
100-
array_api_tests/test_signatures.py::test_func_signature[asin]
101-
array_api_tests/test_signatures.py::test_func_signature[asinh]
102-
array_api_tests/test_signatures.py::test_func_signature[atan]
103-
array_api_tests/test_signatures.py::test_func_signature[atan2]
104-
array_api_tests/test_signatures.py::test_func_signature[atanh]
105-
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
106-
array_api_tests/test_signatures.py::test_func_signature[bitwise_invert]
107-
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
108-
array_api_tests/test_signatures.py::test_func_signature[pow]
10977
array_api_tests/test_signatures.py::test_func_signature[matrix_transpose]
11078
array_api_tests/test_signatures.py::test_func_signature[vecdot]
111-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul]
11279
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky]
11380
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm]
11481
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_rank]
11582
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_transpose]
11683
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
117-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.tensordot]
11884
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
11985
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vector_norm]
12086
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]

0 commit comments

Comments
 (0)
0