[Array API] Add linalg.vecdot · pytorch/pytorch@aa43a44 · GitHub
Commit aa43a44
[Array API] Add linalg.vecdot
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement `\sum x_i y_i`. See the discussion in data-apis/array-api#356.

ghstack-source-id: 13976c0
Pull Request resolved: #70542
1 parent 28776c4 commit aa43a44
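For orientation (not part of the commit), a minimal usage sketch based on the signature and docstring added below; `dim` defaults to the last dimension:

    import torch

    # Batched dot product over the last dimension (the default dim=-1).
    x = torch.randn(5, 3)
    y = torch.randn(5, 3)
    out = torch.linalg.vecdot(x, y)            # shape: (5,)

    # For real inputs this is just the elementwise product summed over dim.
    assert torch.allclose(out, (x * y).sum(dim=-1))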

File tree: 9 files changed, +119 −6 lines


aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 33 additions & 0 deletions
@@ -4102,6 +4102,39 @@ TORCH_IMPL_FUNC(linalg_ldl_solve_out)
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
+Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
+  checkFloatingOrComplex(x, "linalg.vecdot");
+  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
+              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
+              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
+  // out checks
+  TORCH_CHECK(out.scalar_type() == x.scalar_type(),
+              "linalg.vecdot: Expected out of dtype", x.scalar_type(),
+              " but found ", out.scalar_type());
+  checkSameDevice("linalg.vecdot", x, out);
+
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    at::native::resize_output(out, {});
+    return at::vdot_out(out, x, y);
+  } else {
+    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
+  }
+}
+
+Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
+  checkFloatingOrComplex(x, "linalg.vecdot");
+  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
+              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
+              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    return at::vdot(x, y);
+  } else {
+    return x.conj().mul(y).sum(/*dim=*/dim);
+  }
+}
+
 /*
 Solves the matrix equation AX = B for A triangular.
 'left' If true solves AX = B, if false solves XA = B
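To make the dispatch in the new functions easier to follow, here is a rough Python mirror of the same logic; `vecdot_reference` is a hypothetical helper written for illustration, not code from this commit:

    import torch

    def vecdot_reference(x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor:
        # Mirrors the C++ branches above: 1D inputs defer to vdot (which
        # conjugates its first argument); batched inputs reduce conj(x) * y over dim.
        if x.dim() == 1 and y.dim() == 1:
            return torch.vdot(x, y)
        return (x.conj() * y).sum(dim=dim)

    z = torch.randn(2, 4, dtype=torch.cfloat)
    w = torch.randn(2, 4, dtype=torch.cfloat)
    assert torch.allclose(vecdot_reference(z, w), torch.linalg.vecdot(z, w))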

aten/src/ATen/native/native_functions.yaml

Lines changed: 7 additions & 0 deletions
@@ -11612,6 +11612,13 @@
 - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
 
+- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
+  python_module: linalg
+  variants: function
+
+- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
+  python_module: linalg
+
 - func: linalg_matrix_exp(Tensor self) -> Tensor
   python_module: linalg
   variants: function
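The schema entries above make `dim` keyword-only and add an `out=` overload; a hedged sketch of how both surface in Python, assuming the bindings generated from this registration:

    import torch

    x = torch.randn(4, 8)
    y = torch.randn(4, 8)

    # `dim` must be passed by keyword, per the `*` in the schema.
    res = torch.linalg.vecdot(x, y, dim=0)     # reduces over dim 0 -> shape (8,)

    # The .out variant writes into a preallocated tensor of matching dtype and device.
    buf = torch.empty(8)
    torch.linalg.vecdot(x, y, dim=0, out=buf)
    assert torch.allclose(buf, res)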

docs/source/linalg.rst

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ Matrix Products
 
     cross
     matmul
+    vecdot
     multi_dot
     householder_product

test/allowlist_for_publicAPI.json

Lines changed: 1 addition & 0 deletions
@@ -1148,6 +1148,7 @@
     "tensorinv",
     "tensorsolve",
     "vander",
+    "vecdot",
     "vector_norm"
   ],
   "torch.multiprocessing": [

test/test_meta.py

Lines changed: 1 addition & 0 deletions
@@ -497,6 +497,7 @@ def run_meta_crossref(
     torch.nanmean: {bf16, f16, f32, f64},  # TODO(chilli): Doesn't seem to work for some reason?
     torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
     torch.linalg.pinv: {f32, f64},
+    torch.linalg.vecdot: {f16, bf16, f32, f64},  # aten::prod
     torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
 }

torch/_torch_docs.py

Lines changed: 18 additions & 5 deletions
@@ -3513,21 +3513,34 @@ def merge_dicts(*dicts):
            r"""
 vdot(input, other, *, out=None) -> Tensor
 
-Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
-differently than dot(a, b). If the first argument is complex, the complex conjugate of the
-first argument is used for the calculation of the dot product.
+Computes the dot product of two 1D vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
 
 .. note::
 
     Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
     of two 1D tensors with the same number of elements.
 
+.. seealso::
+
+    :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
+
 Args:
     input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
     other (Tensor): second tensor in the dot product, must be 1D.
 
 Keyword args:
-    {out}
+""" + fr"""
+.. note:: {common_args["out"]}
+""" + r"""
 
 Example::
 
@@ -3539,7 +3552,7 @@ def merge_dicts(*dicts):
     tensor([16.+1.j])
     >>> torch.vdot(b, a)
     tensor([16.-1.j])
-""".format(**common_args))
+""")
 
 add_docstr(torch.eig,
            r"""

torch/linalg/__init__.py

Lines changed: 35 additions & 0 deletions
@@ -2801,3 +2801,38 @@
         [ 1,  3,  9],
         [ 1,  5, 25]])
 """)
+
+vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
+linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor
+
+Computes the dot product of two batches of vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.
+
+Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
+It also supports broadcasting.
+
+Args:
+    x (Tensor): first batch of vectors of shape `(*, n)`.
+    y (Tensor): second batch of vectors of shape `(*, n)`.
+
+Keyword args:
+    dim (int): Dimension along which to compute the dot product. Default: `-1`.
+    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
+
+Examples::
+
+    >>> v1 = torch.randn(3, 2)
+    >>> v2 = torch.randn(3, 2)
+    >>> linalg.vecdot(v1, v2)
+    tensor([ 0.3223,  0.2815, -0.1944])
+    >>> torch.vdot(v1[0], v2[0])
+    tensor(0.3223)
+""")

torch/overrides.py

Lines changed: 1 addition & 0 deletions
@@ -923,6 +923,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.ravel: lambda input: -1,
     torch.real: lambda input, out=None: -1,
     torch.vdot: lambda input, other, out=None: -1,
+    torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
     torch.view_as_real: lambda input: -1,
     torch.view_as_complex: lambda input: -1,
     torch.reciprocal: lambda input, out=None: -1,

torch/testing/_internal/common_methods_invocations.py

Lines changed: 22 additions & 1 deletion
@@ -1351,7 +1351,7 @@ class ReductionOpInfo(OpInfo):
     the optional keyword parameters of the ReductionOpInfo constructor.
 
     If a reduction operator does not yet implement the full required API of
-    reduction operators, this should be documented by skipping the failing
+    reduction operators, this should be documented by xfailing the failing
     tests rather than adding optional parameters to ReductionOpInfo.
 
     NOTE
@@ -3045,6 +3045,17 @@ def error_inputs_isclose(op, device, **kwargs):
                  error_regex='atol must be greater than or equal to zero')
 
 
+def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    batches = ((), (0,), (1,), (5,))
+    ns = (0, 1, 3, 5)
+    for b, n in product(batches, ns):
+        shape = b + (n,)
+        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
+        for i in range(len(shape)):
+            yield SampleInput(make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i))
+
+
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),
@@ -12464,6 +12475,16 @@ def error_inputs_mean(op_info, device, **kwargs):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
            ),
+    OpInfo('linalg.vecdot',
+           aten_name='linalg_vecdot',
+           ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
+           dtypes=floating_and_complex_types_and(torch.bfloat16),
+           dtypesIfCUDA=floating_and_complex_types_and(torch.half,
+                                                       *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
+           sample_inputs_func=sample_inputs_linalg_vecdot,
+           check_batched_forward_grad=False,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
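For intuition about what the new OpInfo exercises, here is a hedged standalone sketch that loops over the same batch/length combinations as `sample_inputs_linalg_vecdot` and compares the op against its `ref` lambda (the real test harness drives this through the OpInfo machinery):

    import torch
    from itertools import product

    def ref(x, y, *, dim=-1):
        # Same reference as the OpInfo entry above.
        return (x.conj() * y).sum(dim)

    batches = ((), (0,), (1,), (5,))
    ns = (0, 1, 3, 5)
    for b, n in product(batches, ns):
        shape = b + (n,)
        x, y = torch.randn(shape), torch.randn(shape)
        assert torch.allclose(torch.linalg.vecdot(x, y), ref(x, y))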

0 commit comments