10000 API: Add matmul to linalg · numpy/numpy@3236f0f · GitHub
[go: up one dir, main page]

Skip to content

Commit 3236f0f

Browse files
committed
API: Add matmul to linalg
1 parent 38f0c48 commit 3236f0f

File tree

12 files changed

+92
-57
lines changed

12 files changed

+92
-57
lines changed

doc/release/upcoming_changes/25086.new_feature.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ Array API compatible functions' aliases
99

1010
* Misc: ``concat``, ``permute_dims``, ``pow``.
1111

12-
* linalg: ``tensordot``
12+
* linalg: ``tensordot``, ``matmul``.

doc/source/reference/array_api.rst

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,16 +84,6 @@ Function instead of method
8484
``ndarray`` in ``numpy``.
8585

8686

87-
``linalg`` Namespace Differences
88-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
89-
90-
These functions are in the ``linalg`` sub-namespace in the array API, but are
91-
only in the top-level namespace in NumPy:
92-
93-
- ``matmul`` (*)
94-
95-
(*): These functions are also in the top-level namespace in the array API.
96-
9787
Keyword Argument Renames
9888
~~~~~~~~~~~~~~~~~~~~~~~~
9989

doc/source/reference/routines.linalg.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Matrix and vector products
5959
inner
6060
outer
6161
matmul
62+
linalg.matmul (Array API compatible location)
6263
tensordot
6364
linalg.tensordot (Array API compatible location)
6465
einsum

numpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@
134134
can_cast, cbrt, cdouble, ceil, character, choose, clip, clongdouble,
135135
complex128, complex64, complexfloating, compress, concat, concatenate,
136136
conj, conjugate, convolve, copysign, copyto, correlate, cos, cosh,
137-
count_nonzero, cross, csingle,cumprod, cumproduct, cumsum, datetime64,
137+
count_nonzero, cross, csingle, cumprod, cumproduct, cumsum, datetime64,
138138
datetime_as_string, datetime_data, deg2rad, degrees, diagonal, divide,
139139
divmod, dot, double, dtype, e, einsum, einsum_path, empty, empty_like,
140140
equal, errstate, euler_gamma, exp, exp2, expm1, fabs, finfo, flatiter,

numpy/_core/tests/test_numeric.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4010,14 +4010,6 @@ def test_raise(self):
40104010

40114011
class TestTensordot:
40124012

4013-
@pytest.mark.parametrize(
4014-
"tensordot_func",
4015-
[np.tensordot, np.linalg.tensordot]
4016-
)
4017-
def test_tensordot(self, tensordot_func):
4018-
arr = np.arange(6).reshape((2, 3))
4019-
assert tensordot_func(arr, arr) == 55
4020-
40214013
def test_zero_dimension(self):
40224014
# Test resolution to issue #5663
40234015
a = np.ndarray((3,0))

numpy/linalg/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
multi_dot
3131
matrix_power
3232
tensordot
33+
matmul
3334
3435
Decompositions
3536
--------------

numpy/linalg/__init__.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ from numpy.linalg._linalg import (
2121
cond as cond,
2222
matrix_rank as matrix_rank,
2323
multi_dot as multi_dot,
24+
matmul as matmul,
2425
trace as trace,
2526
diagonal as diagonal,
2627
cross as cross,
2728
)
2829

2930
from numpy._core.numeric import (
30-
tensordot as tensordot
31+
tensordot as tensordot,
3132
)
3233

3334
from numpy._pytesttester import PytestTester

numpy/linalg/_linalg.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
'cholesky', 'eigvals', 'eigvalsh', 'pinv', 'slogdet', 'det',
1414
'svd', 'svdvals', 'eig', 'eigh', 'lstsq', 'norm', 'qr', 'cond',
1515
'matrix_rank', 'LinAlgError', 'multi_dot', 'trace', 'diagonal',
16-
'cross', 'outer', 'tensordot']
16+
'cross', 'outer', 'tensordot', 'matmul']
1717

1818
import functools
1919
import operator
@@ -29,6 +29,7 @@
2929
swapaxes, divide, count_nonzero, isnan, sign, argsort, sort,
3030
reciprocal, overrides, diagonal as _core_diagonal, trace as _core_trace,
3131
cross as _core_cross, outer as _core_outer, tensordot as _core_tensordot,
32+
matmul as _core_matmul,
3233
)
3334
from numpy.lib._twodim_base_impl import triu, eye
3435
from numpy.lib.array_utils import normalize_axis_index
@@ -3131,6 +3132,49 @@ def cross(x1, x2, /, *, axis=-1):
31313132
return _core_cross(x1, x2, axis=axis)
31323133

31333134

3135+
# matmul
3136+
3137+
def _matmul_dispatcher(x1, x2, /):
3138+
return (x1, x2)
3139+
3140+
3141+
@array_function_dispatch(_matmul_dispatcher)
3142+
def matmul(x1, x2, /):
3143+
"""
3144+
Computes the matrix product.
3145+
3146+
This function is Array API compatible, contrary to
3147+
:func:`numpy.matmul`.
3148+
3149+
Parameters
3150+
----------
3151+
x1 : array_like
3152+
The first input array.
3153+
x2 : array_like
3154+
The second input array.
3155+
3156+
Returns
3157+
-------
3158+
out : ndarray
3159+
The matrix product of the inputs.
3160+
This is a scalar only when both ``x1``, ``x2`` are 1-d vectors.
3161+
3162+
Raises
3163+
------
3164+
ValueError
3165+
If the last dimension of ``x1`` is not the same size as
3166+
the second-to-last dimension of ``x2``.
3167+
3168+
If a scalar value is passed in.
3169+
3170+
See Also
3171+
--------
3172+
numpy.matmul
3173+
3174+
"""
3175+
return _core_matmul(x1, x2)
3176+
3177+
31343178
# tensordot
31353179

31363180
def _tensordot_dispatcher(

numpy/linalg/_linalg.pyi

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,23 @@ def cross(
373373
b: _ArrayLikeComplex_co,
374374
axis: int = ...,
375375
) -> NDArray[complexfloating[Any, Any]]: ...
376+
377+
def matmul(
378+
x1: _ArrayLikeInt_co,
379+
x2: _ArrayLikeInt_co,
380+
) -> NDArray[signedinteger[Any]]: ...
381+
@overload
382+
def matmul(
383+
x1: _ArrayLikeUInt_co,
384+
x2: _ArrayLikeUInt_co,
385+
) -> NDArray[unsignedinteger[Any]]: ...
386+
@overload
387+
def matmul(
388+
x1: _ArrayLikeFloat_co,
389+
x2: _ArrayLikeFloat_co,
390+
) -> NDArray[floating[Any]]: ...
391+
@overload
392+
def matmul(
393+
x1: _ArrayLikeComplex_co,
394+
x2: _ArrayLikeComplex_co,
395+
) -> NDArray[complexfloating[Any, Any]]: ...

numpy/linalg/tests/test_linalg.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2254,7 +2254,6 @@ def test_trace():
22542254

22552255

22562256
def test_cross():
2257-
22582257
x = np.arange(9).reshape((3, 3))
22592258
actual = np.linalg.cross(x, x + 1)
22602259
expected = np.array([
@@ -2271,3 +2270,20 @@ d 2851 ef test_cross():
22712270
):
22722271
x_2dim = x[:, 1:]
22732272
np.linalg.cross(x_2dim, x_2dim)
2273+
2274+
2275+
def test_tensordot():
2276+
# np.linalg.tensordot is just an alias for np.tensordot
2277+
x = np.arange(6).reshape((2, 3))
2278+
2279+
assert np.linalg.tensordot(x, x) == 55
2280+
2281+
2282+
def test_matmul():
2283+
# np.linalg.matmul and np.matmul only differs in the number
2284+
# of arguments in the signature
2285+
x = np.arange(6).reshape((2, 3))
2286+
actual = np.linalg.matmul(x, x.T)
2287+
expected = np.array([[5, 14], [14, 50]])
2288+
2289+
assert_equal(actual, expected)

numpy/typing/tests/data/reveal/linalg.pyi

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,7 @@ assert_type(np.linalg.multi_dot([AR_m, AR_m]), Any)
118118
assert_type(np.linalg.cross(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
119119
assert_type(np.linalg.cross(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
120120
assert_type(np.linalg.cross(AR_c16, AR_c16), npt.NDArray[np.complexfloating[Any, Any]])
121+
122+
assert_type(np.linalg.matmul(AR_i8, AR_i8), npt.NDArray[np.signedinteger[Any]])
123+
assert_type(np.linalg.matmul(AR_f8, AR_f8), npt.NDArray[np.floating[Any]])
124+
assert_type(np.linalg.matmul(AR_c16, AR_c16), npt.NDArray[np.complexfloating[Any, Any]])

tools/ci/array-api-skips.txt

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -42,27 +42,12 @@ array_api_tests/test_data_type_functions.py::test_isdtype
4242
array_api_tests/test_data_type_functions.py::test_astype
4343

4444
# missing names
45-
array_api_tests/test_has_names.py::test_has_names[linalg-matmul]
4645
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_norm]
4746
array_api_tests/test_has_names.py::test_has_names[linalg-matrix_transpose]
48-
array_api_tests/test_has_names.py::test_has_names[linalg-tensordot]
4947
array_api_tests/test_has_names.py::test_has_names[linalg-vecdot]
5048
array_api_tests/test_has_names.py::test_has_names[linalg-vector_norm]
51-
array_api_tests/test_has_names.py::test_has_names[manipulation-concat]
52-
array_api_tests/test_has_names.py::test_has_names[manipulation-permute_dims]
5349
array_api_tests/test_has_names.py::test_has_names[data_type-astype]
5450
array_api_tests/test_has_names.py::test_has_names[data_type-isdtype]
55-
array_api_tests/test_has_names.py::test_has_names[elementwise-acos]
56-
array_api_tests/test_has_names.py::test_has_names[elementwise-acosh]
57-
array_api_tests/test_has_names.py::test_has_names[elementwise-asin]
58-
array_api_tests/test_has_names.py::test_has_names[elementwise-asinh]
59-
array_api_tests/test_has_names.py::test_has_names[elementwise-atan]
60-
array_api_tests/test_has_names.py::test_has_names[elementwise-atan2]
61-
array_api_tests/test_has_names.py::test_has_names[elementwise-atanh]
62-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_left_shift]
63-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_invert]
64-
array_api_tests/test_has_names.py::test_has_names[elementwise-bitwise_right_shift]
65-
array_api_tests/test_has_names.py::test_has_names[elementwise-pow]
6651
array_api_tests/test_has_names.py::test_has_names[linear_algebra-matrix_transpose]
6752
array_api_tests/test_has_names.py::test_has_names[linear_algebra-vecdot]
6853
array_api_tests/test_has_names.py::test_has_names[array_method-to_device]
@@ -74,10 +59,6 @@ array_api_tests/test_linalg.py::test_matrix_transpose
7459
array_api_tests/test_linalg.py::test_pinv
7560
array_api_tests/test_linalg.py::test_vecdot
7661

77-
# missing names
78-
array_api_tests/test_manipulation_functions.py::test_concat
79-
array_api_tests/test_manipulation_functions.py::test_permute_dims
80-
8162
# a few misalignments
8263
array_api_tests/test_operators_and_elementwise_functions.py
8364
array_api_tests/test_signatures.py::test_func_signature[std]
@@ -91,33 +72,18 @@ array_api_tests/test_signatures.py::test_func_signature[linspace]
9172
array_api_tests/test_signatures.py::test_func_signature[ones]
9273
array_api_tests/test_signatures.py::test_func_signature[ones_like]
9374
array_api_tests/test_signatures.py::test_func_signature[zeros_like]
94-
array_api_tests/test_signatures.py::test_func_signature[concat]
95-
array_api_tests/test_signatures.py::test_func_signature[permute_dims]
9675
array_api_tests/test_signatures.py::test_func_signature[reshape]
9776
array_api_tests/test_signatures.py::test_func_signature[argsort]
9877
array_api_tests/test_signatures.py::test_func_signature[sort]
9978
array_api_tests/test_signatures.py::test_func_signature[astype]
10079
array_api_tests/test_signatures.py::test_func_signature[isdtype]
101-
array_api_tests/test_signatures.py::test_func_signature[acos]
102-
array_api_tests/test_signatures.py::test_func_signature[acosh]
103-
array_api_tests/test_signatures.py::test_func_signature[asin]
104-
array_api_tests/test_signatures.py::test_func_signature[asinh]
105-
array_api_tests/test_signatures.py::test_func_signature[atan]
106-
array_api_tests/test_signatures.py::test_func_signature[atan2]
107-
array_api_tests/test_signatures.py::test_func_signature[atanh]
108-
array_api_tests/test_signatures.py::test_func_signature[bitwise_left_shift]
109-
array_api_tests/test_signatures.py::test_func_signature[bitwise_invert]
110-
array_api_tests/test_signatures.py::test_func_signature[bitwise_right_shift]
111-
array_api_tests/test_signatures.py::test_func_signature[pow]
11280
array_api_tests/test_signatures.py::test_func_signature[matrix_transpose]
11381
array_api_tests/test_signatures.py::test_func_signature[vecdot]
114-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matmul]
11582
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.cholesky]
11683
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_norm]
11784
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_rank]
11885
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.matrix_transpose]
11986
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.pinv]
120-
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.tensordot]
12187
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vecdot]
12288
array_api_tests/test_signatures.py::test_extension_func_signature[linalg.vector_norm]
12389
array_api_tests/test_signatures.py::test_array_method_signature[__array_namespace__]

0 commit comments

Comments
 (0)
0