8000 API: add matrix_norm, vector_norm, vecdot and matrix_transpose · numpy/numpy@4f392bf · GitHub
[go: up one dir, main page]

Skip to content
/* Override primer focus outline color for marketing header dropdown links for better contrast */ [data-color-mode="light"] .HeaderMenu-dropdown-link:focus-visible, [data-color-mode="light"] .HeaderMenu-trailing-link a:focus-visible { outline-color: var(--color-accent-fg); }

Commit 4f392bf

Browse files
committed
API: add matrix_norm, vector_norm, vecdot and matrix_transpose
1 parent b3d71f7 commit 4f392bf

File tree

17 files changed

+388
-21
lines changed

17 files changed

+388
-21
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
Array API compatible functions for ``numpy.linalg``
2+
---------------------------------------------------
3+
4+
Four new functions were added to improve compatibility with
5+
the Array API standard for `numpy.linalg`:
6+
7+
* `numpy.linalg.matrix_norm`
8+
9+
* `numpy.linalg.vector_norm`
10+
11+
* `numpy.vecdot` and `numpy.linalg.vecdot`
12+
13+
* `numpy.matrix_transpose` and `numpy.linalg.matrix_transpose`

doc/source/reference/routines.array-manipulation.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Transpose-like operations
3232
swapaxes
3333
ndarray.T
3434
transpose
35+
matrix_transpose (Array API compatible)
3536

3637
Changing number of dimensions
3738
=============================

doc/source/reference/routines.linalg.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Matrix and vector products
5656
dot
5757
linalg.multi_dot
5858
vdot
59+
vecdot
5960
inner
6061
outer
6162
matmul
@@ -90,6 +91,8 @@ Norms and other numbers
9091
:toctree: generated/
9192

9293
linalg.norm
94+
linalg.matrix_norm (Array API compatible)
95+
linalg.vector_norm (Array API compatible)
9396
linalg.cond
9497
linalg.det
9598
linalg.matrix_rank
@@ -114,6 +117,7 @@ Other matrix operations
114117

115118
diagonal
116119
linalg.diagonal (Array API compatible)
120+
linalg.matrix_transpose (Array API compatible)
117121

118122
Exceptions
119123
----------

numpy/__init__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -148,22 +148,22 @@
148148
isscalar, issubdtype, lcm, ldexp, left_shift, less, less_equal,
149149
lexsort, linspace, little_endian, log, log10, log1p, log2, logaddexp,
150150
logaddexp2, logical_and, logical_not, logical_or, logical_xor,
151-
logspace, long, longdouble, longlong, matmul, max, maximum,
152-
may_share_memory, mean, memmap, min, min_scalar_type, minimum, mod,
153-
modf, moveaxis, multiply, nan, ndarray, ndim, nditer, negative,
154-
nested_iters, newaxis, nextafter, nonzero, not_equal, number, object_,
155-
ones, ones_like, outer, partition, pi, positive, power, printoptions,
156-
prod, product, promote_types, ptp, put, putmask, rad2deg, radians,
157-
ravel, recarray, reciprocal, record, remainder, repeat, require,
158-
reshape, resize, result_type, right_shift, rint, roll, rollaxis,
159-
round, sctypeDict, searchsorted, set_printoptions, setbufsize, seterr,
160-
seterrcall, shape, shares_memory, short, sign, signbit, signedinteger,
161-
sin, single, sinh, size, sometrue, sort, spacing, sqrt, square,
162-
squeeze, stack, std, str_, subtract, sum, swapaxes, take, tan, tanh,
163-
tensordot, timedelta64, trace, transpose, true_divide, trunc,
164-
typecodes, ubyte, ufunc, uint, uint16, uint32, uint64, uint8, uintc,
165-
uintp, ulong, ulonglong, unsignedinteger, ushort, var, vdot, void,
166-
vstack, where, zeros, zeros_like
151+
logspace, long, longdouble, longlong, matmul, matrix_transpose, max,
152+
maximum, may_share_memory, mean, memmap, min, min_scalar_type,
153+
minimum, mod, modf, moveaxis, multiply, nan, ndarray, ndim, nditer,
154+
negative, nested_iters, newaxis, nextafter, nonzero, not_equal,
155+
number, object_, ones, ones_like, outer, partition, pi, positive,
156+
power, printoptions, prod, product, promote_types, ptp, put, putmask,
157+
rad2deg, radians, ravel, recarray, reciprocal, record, remainder,
158+
repeat, require, reshape, resize, result_type, right_shift, rint,
159+
roll, rollaxis, round, sctypeDict, searchsorted, set_printoptions,
160+
setbufsize, seterr, seterrcall, shape, shares_memory, short, sign,
161+
signbit, signedinteger, sin, single, sinh, size, sometrue, sort,
162+
spacing, sqrt, square, squeeze, stack, std, str_, subtract, sum,
163+
swapaxes, take, tan, tanh, tensordot, timedelta64, trace, transpose,
164+
true_divide, trunc, typecodes, ubyte, ufunc, uint, uint16, uint32,
165+
uint64, uint8, uintc, uintp, ulong, ulonglong, unsignedinteger,
166+
ushort, var, vdot, vecdot, void, vstack, where, zeros, zeros_like
167167
)
168168

169169
# NOTE: It's still under discussion whether these aliases

numpy/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ from numpy._core.fromnumeric import (
235235
put as put,
236236
swapaxes as swapaxes,
237237
transpose as transpose,
238+
matrix_transpose as matrix_transpose,
238239
partition as partition,
239240
argpartition as argpartition,
240241
sort as sort,
@@ -362,6 +363,7 @@ from numpy._core.numeric import (
362363
convolve as convolve,
363364
outer as outer,
364365
tensordot as tensordot,
366+
vecdot as vecdot,
365367
roll as roll,
366368
rollaxis as rollaxis,
367369
moveaxis as moveaxis,

numpy/_core/fromnumeric.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
'all', 'alltrue', 'amax', 'amin', 'any', 'argmax',
2222
'argmin', 'argpartition', 'argsort', 'around', 'choose', 'clip',
2323
'compress', 'cumprod', 'cumproduct', 'cumsum', 'diagonal', 'mean',
24-
'max', 'min',
24+
'max', 'min', 'matrix_transpose',
2525
'ndim', 'nonzero', 'partition', 'prod', 'product', 'ptp', 'put',
2626
'ravel', 'repeat', 'reshape', 'resize', 'round',
2727
'searchsorted', 'shape', 'size', 'sometrue', 'sort', 'squeeze',
@@ -655,6 +655,41 @@ def transpose(a, axes=None):
655655
return _wrapfunc(a, 'transpose', axes)
656656

657657

658+
def _matrix_transpose_dispatcher(x):
659+
return (x,)
660+
661+
@array_function_dispatch(_matrix_transpose_dispatcher)
662+
def matrix_transpose(x, /):
663+
"""
664+
Transposes a matrix (or a stack of matrices) ``x``.
665+
666+
This function is Array API compatible.
667+
668+
Parameters
669+
----------
670+
x : array_like
671+
Input array having shape (..., M, N) and whose two innermost
672+
dimensions form ``MxN`` matrices.
673+
674+
Returns
675+
-------
676+
out : ndarray
677+
An array containing the transpose for each matrix and having shape
678+
(..., N, M).
679+
680+
See Also
681+
--------
682+
transpose : Generic transpose method.
683+
684+
"""
685+
x = asarray(x)
686+
if x.ndim < 2:
687+
raise ValueError(
688+
f"Input array must be at least 2-dimensional, but it is {x.ndim}"
689+
)
690+
return swapaxes(x, -1, -2)
691+
692+
658693
def _partition_dispatcher(a, kth, axis=None, kind=None, order=None):
659694
return (a,)
660695

numpy/_core/fromnumeric.pyi

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,11 @@ def transpose(
175175
axes: None | _ShapeLike = ...
176176
) -> NDArray[Any]: ...
177177

178+
@overload
179+
def matrix_transpose(x: _ArrayLike[_SCT]) -> NDArray[_SCT]: ...
180+
@overload
181+
def matrix_transpose(x: ArrayLike) -> NDArray[Any]: ...
182+
178183
@overload
179184
def partition(
180185
a: _ArrayLike[_SCT],

numpy/_core/numeric.py

Expand all lines: numpy/_core/numeric.py
Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
'can_cast', 'promote_types', 'min_scalar_type',
4949
'result_type', 'isfortran', 'empty_like', 'zeros_like', 'ones_like',
5050
'correlate', 'convolve', 'inner', 'dot', 'outer', 'vdot', 'roll',
51-
'rollaxis', 'moveaxis', 'cross', 'tensordot', 'little_endian',
51+
'rollaxis', 'moveaxis', 'cross', 'tensordot', 'vecdot', 'little_endian',
5252
'fromiter', 'array_equal', 'array_equiv', 'indices', 'fromfunction',
5353
'isclose', 'isscalar', 'binary_repr', 'base_repr', 'ones',
5454
'identity', 'allclose', 'putmask',
@@ -1121,6 +1121,53 @@ def tensordot(a, b, axes=2):
11211121
return res.reshape(olda + oldb)
11221122

11231123

1124+
def _vecdot_dispatcher(x1, x2, /, *, axis=None):
1125+
return (x1, x2)
1126+
1127+
1128+
@array_function_dispatch(_vecdot_dispatcher)
1129+
def vecdot(x1, x2, /, *, axis=-1):
1130+
"""
1131+
Computes the (vector) dot product of two arrays.
1132+
1133+
This function is Array API compatible.
1134+
1135+
Parameters
1136+
----------
1137+
x1 : array_like
1138+
First input array.
1139+
x2 : array_like
1140+
Second input array.
1141+
axis : int, optional
1142+
Axis over which to compute the dot product. Default: ``-1``.
1143+
1144+
Returns
1145+
-------
1146+
output : ndarray
1147+
The vector dot product of the input.
1148+
1149+
See Also
1150+
--------
1151+
dot
1152+
1153+
"""
1154+
ndim = builtins.max(x1.ndim, x2.ndim)
1155+
x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape)
1156+
x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape)
1157+
if x1_shape[axis] != x2_shape[axis]:
1158+
raise ValueError(
1159+
"x1 and x2 must have the same size along the dot-compute axis "
1160+
f"but they are: {x1_shape[axis]} and {x2_shape[axis]}."
1161+
)
1162+
1163+
x1_, x2_ = np.broadcast_arrays(x1, x2)
1164+
x1_ = np.moveaxis(x1_, axis, -1)
1165+
x2_ = np.moveaxis(x2_, axis, -1)
1166+
1167+
res = x1_[..., None, :] @ x2_[..., None]
1168+
return res[..., 0, 0]
1169+
1170+
11241171
def _roll_dispatcher(a, shift, axis=None):
11251172
return (a,)
11261173

numpy/_core/numeric.pyi

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,39 @@ def tensordot(
461461
axes: int | tuple[_ShapeLike, _ShapeLike] = ...,
462462
) -> NDArray[object_]: ...
463463

464+
@overload
465+
def vecdot(
466+
x1: _ArrayLikeUnknown, x2: _ArrayLikeUnknown, axis: int = ...
467+
) -> NDArray[Any]: ...
468+
@overload
469+
def vecdot(
470+
x1: _ArrayLikeBool_co, x2: _ArrayLikeBool_co, axis: int = ...
471+
) -> NDArray[bool_]: ...
472+
@overload
473+
def vecdot(
474+
x1: _ArrayLikeUInt_co, x2: _ArrayLikeUInt_co, axis: int = ...
475+
) -> NDArray[unsignedinteger[Any]]: ...
476+
@overload
477+
def vecdot(
478+
x1: _ArrayLikeInt_co, x2: _ArrayLikeInt_co, axis: int = ...
479+
) -> NDArray[signedinteger[Any]]: ...
480+
@overload
481+
def vecdot(
482+
x1: _ArrayLikeFloat_co, x2: _ArrayLikeFloat_co, axis: int = ...
483+
) -> NDArray[floating[Any]]: ...
484+
@overload
485+
def vecdot(
486+
x1: _ArrayLikeComplex_co, x2: _ArrayLikeComplex_co, axis: int = ...
487+
) -> NDArray[complexfloating[Any, Any]]: ...
488+
@overload
489+
def vecdot(
490+
x1: _ArrayLikeTD64_co, x2: _ArrayLikeTD64_co, axis: int = ...
491+
) -> NDArray[timedelta64]: ...
492+
@overload
493+
def vecdot(
494+
x1: _ArrayLikeObject_co, x2: _ArrayLikeObject_co, axis: int = ...
495+
) -> NDArray[object_]: ...
496+
464497
@overload
465498
def roll(
466499
a: _ArrayLike[_SCT],

numpy/_core/tests/test_numeric.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def test_transpose(self):
286286
arr = [[1, 2], [3, 4], [5, 6]]
287287
tgt = [[1, 3, 5], [2, 4, 6]]
288288
assert_equal(np.transpose(arr, (1, 0)), tgt)
289+
assert_equal(np.matrix_transpose(arr), tgt)
289290

290291
def test_var(self):
291292
A = [[1, 2, 3], [4, 5, 6]]
@@ -4023,3 +4024,15 @@ def test_zero_dimensional(self):
40234024
arr_0d = np.array(1)
40244025
ret = np.tensordot(arr_0d, arr_0d, ([], [])) # contracting no axes is well defined
40254026
assert_array_equal(ret, arr_0d)
4027+
4028+
4029+
class TestVecdot:
4030+
4031+
def test_vecdot(self):
4032+
arr1 = np.arange(6).reshape((2, 3))
4033+
arr2 = np.arange(3).reshape((1, 3))
4034+
4035+
actual = np.vecdot(arr1, arr2)
4036+
expected = np.array([5, 14])
4037+
4038+
assert_array_equal(actual, expected)

0 commit comments

Comments
 (0)
0