8000 Merge pull request #25675 from mhvk/matvec-vecmat-ufuncs · numpy/numpy@cd84af2 · GitHub
[go: up one dir, main page]

Skip to content

Commit cd84af2

Browse files
authored
Merge pull request #25675 from mhvk/matvec-vecmat-ufuncs
ENH: add matvec and vecmat gufuncs
2 parents c31bd0b + 8cec646 commit cd84af2

File tree

12 files changed

+449
-54
lines changed

12 files changed

+449
-54
lines changed

benchmarks/benchmarks/bench_ufunc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
'isinf', 'isnan', 'isnat', 'lcm', 'ldexp', 'left_shift', 'less',
1717
'less_equal', 'log', 'log10', 'log1p', 'log2', 'logaddexp',
1818
'logaddexp2', 'logical_and', 'logical_not', 'logical_or',
19-
'logical_xor', 'matmul', 'maximum', 'minimum', 'mod', 'modf',
20-
'multiply', 'negative', 'nextafter', 'not_equal', 'positive',
19+
'logical_xor', 'matmul', 'matvec', 'maximum', 'minimum', 'mod',
20+
'modf', 'multiply', 'negative', 'nextafter', 'not_equal', 'positive',
2121
'power', 'rad2deg', 'radians', 'reciprocal', 'remainder',
2222
'right_shift', 'rint', 'sign', 'signbit', 'sin',
2323
'sinh', 'spacing', 'sqrt', 'square', 'subtract', 'tan', 'tanh',
24-
'true_divide', 'trunc', 'vecdot']
24+
'true_divide', 'trunc', 'vecdot', 'vecmat']
2525
arrayfuncdisp = ['real', 'round']
2626

2727
for name in ufuncs:
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
New functions for matrix-vector and vector-matrix products
2+
----------------------------------------------------------
3+
4+
Two new generalized ufuncs were defined:
5+
6+
* `numpy.matvec` - matrix-vector product, treating the arguments as
7+
stacks of matrices and column vectors, respectively.
8+
9+
* `numpy.vecmat` - vector-matrix product, treating the arguments as
10+
stacks of column vectors and matrices, respectively. For complex
11+
vectors, the conjugate is taken.
12+
13+
These add to the existing `numpy.matmul` as well as to `numpy.vecdot`,
14+
which was added in numpy 2.0.
15+
16+
Note that `numpy.matmul` never takes a complex conjugate, also not
17+
when its left input is a vector, while both `numpy.vecdot` and
18+
`numpy.vecmat` do take the conjugate for complex vectors on the
19+
left-hand side (which are taken to be the ones that are transposed,
20+
following the physics convention).

doc/source/reference/routines.linalg.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ Matrix and vector products
6262
outer
6363
matmul
6464
linalg.matmul (Array API compatible location)
65+
matvec
66+
vecmat
6567
tensordot
6668
linalg.tensordot (Array API compatible location)
6769
einsum

numpy/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,10 @@
151151
left_shift, less, less_equal, lexsort, linspace, little_endian, log,
152152
log10, log1p, log2, logaddexp, logaddexp2, logical_and, logical_not,
153153
logical_or, logical_xor, logspace, long, longdouble, longlong, matmul,
154-
matrix_transpose, max, maximum, may_share_memory, mean, memmap, min,
155-
min_scalar_type, minimum, mod, modf, moveaxis, multiply, nan, ndarray,
156-
ndim, nditer, negative, nested_iters, newaxis, nextafter, nonzero,
157-
not_equal, number, object_, ones, ones_like, outer, partition,
154+
matvec, matrix_transpose, max, maximum, may_share_memory, mean, memmap,
155+
min, min_scalar_type, minimum, mod, modf, moveaxis, multiply, nan,
156+
ndarray, ndim, nditer, negative, nested_iters, newaxis, nextafter,
157+
nonzero, not_equal, number, object_, ones, ones_like, outer, partition,
158158
permute_dims, pi, positive, pow, power, printoptions, prod,
159159
promote_types, ptp, put, putmask, rad2deg, radians, ravel, recarray,
160160
reciprocal, record, remainder, repeat, require, reshape, resize,
@@ -165,8 +165,8 @@
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166166
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168-
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot, void,
169-
vstack, where, zeros, zeros_like
168+
ulonglong, unsignedinteger, unstack, ushort, var, vdot, vecdot,
169+
vecmat, void, vstack, where, zeros, zeros_like
170170
)
171171

172172
# NOTE: It's still under discussion whether these aliases

numpy/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4490,6 +4490,7 @@ logical_not: _UFunc_Nin1_Nout1[L['logical_not'], L[20], None]
44904490
logical_or: _UFunc_Nin2_Nout1[L['logical_or'], L[20], L[False]]
44914491
logical_xor: _UFunc_Nin2_Nout1[L['logical_xor'], L[19], L[False]]
44924492
matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None, L["(n?,k),(k,m?)->(n?,m?)"]]
4493+
matvec: _GUFunc_Nin2_Nout1[L['matvec'], L[19], None, L["(m,n),(n)->(m)"]]
44934494
maximum: _UFunc_Nin2_Nout1[L['maximum'], L[21], None]
44944495
minimum: _UFunc_Nin2_Nout1[L['minimum'], L[21], None]
44954496
mod: _UFunc_Nin2_Nout1[L['remainder'], L[16], None]
@@ -4519,6 +4520,7 @@ tanh: _UFunc_Nin1_Nout1[L['tanh'], L[8], None]
45194520
true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
45204521
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
45214522
vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None, L["(n),(n)->()"]]
4523+
vecmat: _GUFunc_Nin2_Nout1[L['vecmat'], L[19], None, L["(n),(n,m)->(m)"]]
45224524

45234525
abs = absolute
45244526
acos = arccos

numpy/_core/code_generators/generate_umath.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,6 +1152,22 @@ def english_upper(s):
11521152
TD(O),
11531153
signature='(n),(n)->()',
11541154
),
1155+
'matvec':
1156+
Ufunc(2, 1, None,
1157+
docstrings.get('numpy._core.umath.matvec'),
1158+
" 10000 PyUFunc_SimpleUniformOperationTypeResolver",
1159+
TD(notimes_or_obj),
1160+
TD(O),
1161+
signature='(m,n),(n)->(m)',
1162+
),
1163+
'vecmat':
1164+
Ufunc(2, 1, None,
1165+
docstrings.get('numpy._core.umath.vecmat'),
1166+
"PyUFunc_SimpleUniformOperationTypeResolver",
1167+
TD(notimes_or_obj),
1168+
TD(O),
1169+
signature='(n),(n,m)->(m)',
1170+
),
11551171
'str_len':
11561172
Ufunc(1, 1, Zero,
11571173
docstrings.get('numpy._core.umath.str_len'),

numpy/_core/code_generators/ufunc_docstrings.py

Lines changed: 140 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def add_newdoc(place, name, doc):
4444

4545
skip = (
4646
# gufuncs do not use the OUT_SCALAR replacement strings
47-
'matmul', 'vecdot',
47+
'matmul', 'vecdot', 'matvec', 'vecmat',
4848
# clip has 3 inputs, which is not handled by this
4949
'clip',
5050
)
@@ -2793,7 +2793,9 @@ def add_newdoc(place, name, doc):
27932793
27942794
See Also
27952795
--------
2796-
vdot : Complex-conjugating dot product.
2796+
vecdot : Complex-conjugating dot product for stacks of vectors.
2797+
matvec : Matrix-vector product for stacks of matrices and vectors.
2798+
vecmat : Vector-matrix product for stacks of vectors and matrices.
27972799
tensordot : Sum products over arbitrary axes.
27982800
einsum : Einstein summation convention.
27992801
dot : alternative matrix product with different broadcasting rules.
@@ -2808,10 +2810,10 @@ def add_newdoc(place, name, doc):
28082810
matrices residing in the last two indexes and broadcast accordingly.
28092811
- If the first argument is 1-D, it is promoted to a matrix by
28102812
prepending a 1 to its dimensions. After matrix multiplication
2811-
the prepended 1 is removed.
2813+
the prepended 1 is removed. (For stacks of vectors, use ``vecmat``.)
28122814
- If the second argument is 1-D, it is promoted to a matrix by
28132815
appending a 1 to its dimensions. After matrix multiplication
2814-
the appended 1 is removed.
2816+
the appended 1 is removed. (For stacks of vectors, use ``matvec``.)
28152817
28162818
``matmul`` differs from ``dot`` in two important ways:
28172819
@@ -2910,8 +2912,8 @@ def add_newdoc(place, name, doc):
29102912
Input arrays, scalars not allowed.
29112913
out : ndarray, optional
29122914
A location into which the result is stored. If provided, it must have
2913-
a shape that the broadcasted shape of `x1` and `x2` with the last axis
2914-
removed. If not provided or None, a freshly-allocated array is used.
2915+
the broadcasted shape of `x1` and `x2` with the last axis removed.
2916+
If not provided or None, a freshly-allocated array is used.
29152917
**kwargs
29162918
For other keyword-only arguments, see the
29172919
:ref:`ufunc docs <ufuncs.kwargs>`.
@@ -2933,6 +2935,9 @@ def add_newdoc(place, name, doc):
29332935
See Also
29342936
--------
29352937
vdot : same but flattens arguments first
2938+
matmul : Matrix-matrix product.
2939+
vecmat : Vector-matrix product.
2940+
matvec : Matrix-vector product.
29362941
einsum : Einstein summation convention.
29372942
29382943
Examples
@@ -2949,6 +2954,135 @@ def add_newdoc(place, name, doc):
29492954
.. versionadded:: 2.0.0
29502955
""")
29512956

2957+
add_newdoc('numpy._core.umath', 'matvec',
2958+
"""
2959+
Matrix-vector dot product of two arrays.
2960+
2961+
Given a matrix (or stack of matrices) :math:`\\mathbf{A}` in ``x1`` and
2962+
a vector (or stack of vectors) :math:`\\mathbf{v}` in ``x2``, the
2963+
matrix-vector product is defined as:
2964+
2965+
.. math::
2966+
\\mathbf{A} \\cdot \\mathbf{b} = \\sum_{j=0}^{n-1} A_{ij} v_j
2967+
2968+
where the sum is over the last dimensions in ``x1`` and ``x2``
2969+
(unless ``axes`` is specified). (For a matrix-vector product with the
2970+
vector conjugated, use ``np.vecmat(x2, x1.mT)``.)
2971+
2972+
Parameters
2973+
----------
2974+
x1, x2 : array_like
2975+
Input arrays, scalars not allowed.
2976+
out : ndarray, optional
2977+
A location into which the result is stored. If provided, it must have
2978+
the broadcasted shape of ``x1`` and ``x2`` with the summation axis
2979+
removed. If not provided or None, a freshly-allocated array is used.
2980+
**kwargs
2981+
For other keyword-only arguments, see the
2982+
:ref:`ufunc docs <ufuncs.kwargs>`.
2983+
2984+
Returns
2985+
-------
2986+
y : ndarray
2987+
The matrix-vector product of the inputs.
2988+
2989+
Raises
2990+
------
2991+
ValueError
2992+
If the last dimensions of ``x1`` and ``x2`` are not the same size.
2993+
2994+
If a scalar value is passed in.
2995+
2996+
See Also
2997+
--------
2998+
vecdot : Vector-vector product.
2999+
vecmat : Vector-matrix product.
3000+
matmul : Matrix-matrix product.
3001+
einsum : Einstein summation convention.
3002+
3003+
Examples
3004+
--------
3005+
Rotate a set of vectors from Y to X along Z.
3006+
3007+
>>> a = np.array([[0., 1., 0.],
3008+
... [-1., 0., 0.],
3009+
... [0., 0., 1.]])
3010+
>>> v = np.array([[1., 0., 0.],
3011+
... [0., 1., 0.],
3012+
... [0., 0., 1.],
3013+
... [0., 6., 8.]])
3014+
< 10000 span class=pl-s> >>> np.matvec(a, v)
3015+
array([[ 0., -1., 0.],
3016+
[ 1., 0., 0.],
3017+
[ 0., 0., 1.],
3018+
[ 6., 0., 8.]])
3019+
3020+
.. versionadded:: 2.1.0
3021+
""")
3022+
3023+
add_newdoc('numpy._core.umath', 'vecmat',
3024+
"""
3025+
Vector-matrix dot product of two arrays.
3026+
3027+
Given a vector (or stack of vector) :math:`\\mathbf{v}` in ``x1`` and
3028+
a matrix (or stack of matrices) :math:`\\mathbf{A}` in ``x2``, the
3029+
vector-matrix product is defined as:
3030+
3031+
.. math::
3032+
\\mathbf{b} \\cdot \\mathbf{A} = \\sum_{i=0}^{n-1} \\overline{v_i}A_{ij}
3033+
3034+
where the sum is over the last dimension of ``x1`` and the one-but-last
3035+
dimensions in ``x2`` (unless `axes` is specified) and where
3036+
:math:`\\overline{v_i}` denotes the complex conjugate if :math:`v`
3037+
is complex and the identity otherwise. (For a non-conjugated vector-matrix
3038+
product, use ``np.matvec(x2.mT, x1)``.)
3039+
3040+
Parameters
3041+
----------
3042+
x1, x2 : array_like
3043+
Input arrays, scalars not allowed.
3044+
out : ndarray, optional
3045+
A location into which the result is stored. If provided, it must have
3046+
the broadcasted shape of ``x1`` and ``x2`` with the summation axis
3047+
removed. If not provided or None, a freshly-allocated array is used.
3048+
**kwargs
3049+
For other keyword-only arguments, see the
3050+
:ref:`ufunc docs <ufuncs.kwargs>`.
3051+
3052+
Returns
3053+
-------
3054+
y : ndarray
3055+
The vector-matrix product of the inputs.
3056+
3057+
Raises
3058+
------
3059+
ValueError
3060+
If the last dimensions of ``x1`` and the one-but-last dimension of
3061+
``x2`` are not the same size.
3062+
3063+
If a scalar value is passed in.
3064+
3065+
See Also
3066+
--------
3067+
vecdot : Vector-vector product.
3068+
matvec : Matrix-vector product.
3069+
matmul : Matrix-matrix product.
3070+
einsum : Einstein summation convention.
3071+
3072+
Examples
3073+
--------
3074+
Project a vector along X and Y.
3075+
3076+
>>> v = np.array([0., 4., 2.])
3077+
>>> a = np.array([[1., 0., 0.],
3078+
... [0., 1., 0.],
3079+
... [0., 0., 0.]])
3080+
>>> np.vecmat(v, a)
3081+
array([ 0., 4., 0.])
3082+
3083+
.. versionadded:: 2.1.0
3084+
""")
3085+
29523086
add_newdoc('numpy._core.umath', 'modf',
29533087
"""
29543088
Return the fractional and integral parts of an array, element-wise.

numpy/_core/multiarray.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ def _override___module__():
8383
'isfinite', 'isinf', 'isnan', 'isnat', 'lcm', 'ldexp', 'less',
8484
'less_equal', 'log', 'log10', 'log1p', 'log2', 'logaddexp',
8585
'logaddexp2', 'logical_and', 'logical_not', 'logical_or',
86-
'logical_xor', 'matmul', 'maximum', 'minimum', 'remainder', 'modf',
87-
'multiply', 'negative', 'nextafter', 'not_equal', 'positive', 'power',
88-
'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit', 'sin',
89-
'sinh', 'spacing', 'sqrt', 'square', 'subtract', 'tan', 'tanh',
90-
'trunc', 'vecdot',
86+
'logical_xor', 'matmul', 'matvec', 'maximum', 'minimum', 'remainder',
87+
'modf', 'multiply', 'negative', 'nextafter', 'not_equal', 'positive',
88+
'power', 'rad2deg', 'radians', 'reciprocal', 'rint', 'sign', 'signbit',
89+
'sin', 'sinh', 'spacing', 'sqrt', 'square', 'subtract', 'tan', 'tanh',
90+
'trunc', 'vecdot', 'vecmat',
9191
]:
9292
ufunc = namespace_names[ufunc_name]
9393
ufunc.__module__ = "numpy"

0 commit comments

Comments
 (0)
0