[Array API] Add linalg.vecdot · pytorch/pytorch@cc40d5e · GitHub
Commit cc40d5e
[Array API] Add linalg.vecdot
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement \sum x_i y_i. See the discussion in data-apis/array-api#356.

ghstack-source-id: e51aaed
Pull Request resolved: #70542
1 parent ce86881 commit cc40d5e
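
As a quick sanity check of what the new function computes, here is a hedged sketch (the shapes and complex dtype are illustrative only); the equivalence follows the OpInfo `ref` lambda and the new docstring added further down in this commit:

>>> import torch
>>> x = torch.randn(3, 4, dtype=torch.cfloat)
>>> y = torch.randn(3, 4, dtype=torch.cfloat)
>>> expected = (x.conj() * y).sum(dim=-1)   # conjugate-then-sum, as in the OpInfo ref below
>>> torch.allclose(torch.linalg.vecdot(x, y, dim=-1), expected)
True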

File tree

8 files changed: +103 −7 lines changed


aten/src/ATen/core/interned_strings.h

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ namespace c10 {
 _(aten, mH) \
 _(aten, linalg_matrix_power) \
 _(aten, chain_matmul) \
+_(aten, linalg_vecdot) \
 _(aten, linalg_multi_dot) \
 _(aten, linalg_norm) \
 _(aten, linalg_vector_norm) \

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 19 additions & 0 deletions
@@ -3836,6 +3836,25 @@ TransposeType to_transpose_type(const bool contig, const bool conj) {
 }
 } // end of anonymous namespace
 
+Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    at::native::resize_output(out, {});
+    return at::vdot_out(out, x, y);
+  } else {
+    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
+  }
+}
+
+Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    return at::vdot(x, y);
+  } else {
+    return x.conj().mul(y).sum(/*dim=*/dim);
+  }
+}
+
 /*
 Solves the matrix equation AX = B for A triangular.
 'left' If true solves AX = B, if false solves XA = B
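
A rough Python-level reading of this composite kernel, assuming the two branches behave exactly as written above (1-D inputs route to `torch.vdot`, anything batched falls back to a conjugated multiply followed by a sum over `dim`); `vecdot_reference` is a hypothetical helper name, not part of the commit:

import torch

def vecdot_reference(x, y, dim=-1):
    # Mirrors linalg_vecdot above: vdot for plain 1-D vectors,
    # otherwise conjugate x, multiply elementwise, and reduce over `dim`.
    if x.dim() == 1 and y.dim() == 1:
        return torch.vdot(x, y)
    return (x.conj() * y).sum(dim=dim)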

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
@@ -10844,6 +10844,13 @@
 - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
 
+- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
 - func: linalg_matrix_exp(Tensor self) -> Tensor
   python_module: linalg
   variants: function
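
These schema entries register both a functional and an `out=` overload under the `linalg` Python module. A hedged usage sketch of the `out=` form (input shapes are arbitrary; note the `test_out` xfail in the OpInfo further down, which tracks a `torch.sum(out=)` issue):

>>> import torch
>>> x = torch.randn(2, 3)
>>> y = torch.randn(2, 3)
>>> out = torch.empty(2)   # shape left after reducing over dim=-1
>>> _ = torch.linalg.vecdot(x, y, dim=-1, out=out)
>>> out.shape
torch.Size([2])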

docs/source/linalg.rst

Lines changed: 1 addition & 0 deletions
@@ -80,6 +80,7 @@ Matrix Products
 
     cross
     matmul
+    vecdot
     multi_dot
     householder_product

torch/_torch_docs.py

Lines changed: 18 additions & 5 deletions
@@ -3429,21 +3429,34 @@ def merge_dicts(*dicts):
 r"""
 vdot(input, other, *, out=None) -> Tensor
 
-Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
-differently than dot(a, b). If the first argument is complex, the complex conjugate of the
-first argument is used for the calculation of the dot product.
+Computes the dot product of two 1D vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
 
 .. note::
 
     Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
     of two 1D tensors with the same number of elements.
 
+.. seealso::
+
+    :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
+
 Args:
     input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
    other (Tensor): second tensor in the dot product, must be 1D.
 
 Keyword args:
-    {out}
+""" + fr"""
+.. note:: {common_args["out"]}
+""" + r"""
 
 Example::
@@ -3455,7 +3468,7 @@ def merge_dicts(*dicts):
     tensor([16.+1.j])
     >>> torch.vdot(b, a)
     tensor([16.-1.j])
-""".format(**common_args))
+""")
 
 add_docstr(torch.eig,
            r"""

torch/linalg/__init__.py

Lines changed: 35 additions & 0 deletions
@@ -2251,3 +2251,38 @@
     >>> torch.dist(Q.mT @ Q, torch.eye(4))
     tensor(6.2158e-07)
 """)
+
+vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
+linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor
+
+Computes the dot product of two batches of vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
+
+Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
+It also supports broadcasting.
+
+Args:
+    x (Tensor): first batch of vectors.
+    y (Tensor): second batch of vectors.
+
+Keyword args:
+    dim (int): Dimension along which to compute the dot product. Default: `-1`.
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Examples::
+
+    >>> v1 = torch.randn(3, 2)
+    >>> v2 = torch.randn(3, 2)
+    >>> linalg.vecdot(v1, v2)
+    tensor([ 0.3223, 0.2815, -0.1944])
+    >>> torch.vdot(v1[0], v2[0])
+    tensor(0.3223)
+""")

torch/overrides.py

Lines changed: 1 addition & 0 deletions

@@ -869,6 +869,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.ravel: lambda input: -1,
     torch.real: lambda input, out=None: -1,
     torch.vdot: lambda input, other, out=None: -1,
+    torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
     torch.view_as_real: lambda input: -1,
     torch.view_as_complex: lambda input: -1,
     torch.reciprocal: lambda input, out=None: -1,

torch/testing/_internal/common_methods_invocations.py

Lines changed: 21 additions & 2 deletions
@@ -968,7 +968,7 @@ def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
     supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True)
 
     # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
-    # use op_info.genearte_args_kwargs directly.
+    # use op_info.generate_args_kwargs directly.
     generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {}))
 
     inputs: List[SampleInput] = []
@@ -1101,7 +1101,7 @@ class ReductionOpInfo(OpInfo):
     the optional keyword parameters of the ReductionOpInfo constructor.
 
     If a reduction operator does not yet implement the full required API of
-    reduction operators, this should be documented by skipping the failing
+    reduction operators, this should be documented by xfailing the failing
     tests rather than adding optional parameters to ReductionOpInfo.
 
     NOTE
@@ -2012,6 +2012,15 @@ def sample_inputs_isclose(
     yield SampleInput(lhs, args=(rhs,),
                       kwargs=dict(op_kwargs, rtol=rtol, atol=atol, equal_nan=equal_nan))
 
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    yield from sample_inputs_binary_pwise(op_info, device, dtype, requires_grad)
+
+    # Add also samples with dim != -1
+    for s in sample_inputs_binary_pwise(op_info, device, dtype, requires_grad):
+        if s.input.ndim > 1:
+            s.kwargs["dim"] = 0
+            yield s
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
@@ -9778,6 +9787,16 @@ def ref_pairwise_distance(input1, input2):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            ),
+    OpInfo('linalg.vecdot',
+           aten_name='linalg_vecdot',
+           ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
+           sample_inputs_func=sample_inputs_linalg_vecdot,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # FIXME torch.sum(out=) has an incorrect behaviour
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),),),
    OpInfo('linalg.cond',
           aten_name='linalg_cond',
           dtypes=floating_and_complex_types(),