Update on "[Array API] Add linalg.vecdot" · pytorch/pytorch@a426304 · GitHub
Commit a426304

Update on "[Array API] Add linalg.vecdot"
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement \sum_i \overline{x_i} y_i, i.e. it conjugates the first argument, computing x^H y. See the discussion in data-apis/array-api#356.

Edit: when it comes to testing, this function is neither quite a binop nor quite a reduction op. As such, we are very close to being able to reuse the extra testing those categories get, but we do not quite make it. It is such a simple op, though, that we should be fine without it.

Resolves #18027.

cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

[ghstack-poisoned]
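For reference, a minimal NumPy sketch of the chosen complex semantics (illustrative only, not the ATen implementation; `vecdot_ref` is a hypothetical name):

```python
import numpy as np

# Sketch of the semantics this PR implements: the first argument is
# conjugated, i.e. vecdot computes x^H y reduced over `dim`.
def vecdot_ref(x, y, dim=-1):
    return (np.conj(x) * y).sum(axis=dim)

x = np.array([1 + 1j, 2 - 1j])
y = np.array([3 + 0j, 1 + 2j])
result = vecdot_ref(x, y)  # conj(x) . y
```

For 1-D inputs this agrees with `np.vdot`, which also conjugates its first argument.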
1 parent 69c1492 commit a426304

File tree

4 files changed: +38 −22 lines

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 13 additions & 2 deletions

@@ -3837,11 +3837,22 @@ TransposeType to_transpose_type(const bool contig, const bool conj) {
 } // end of anonymous namespace

 Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
-  return at::sum_out(out, x * y, /*dim=*/dim);
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    at::native::resize_output(out, {});
+    return at::vdot_out(out, x, y);
+  } else {
+    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
+  }
 }

 Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
-  return (x * y).sum(/*dim=*/dim);
+  // Computes x^H y
+  if (x.dim() == 1 && y.dim() == 1) {
+    return at::vdot(x, y);
+  } else {
+    return x.conj().mul(y).sum(/*dim=*/dim);
+  }
 }

 /*
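The dispatch in this file (a 1-D fast path through vdot, with a conj-multiply-sum fallback for batched inputs) can be mirrored in a short NumPy sketch; `vecdot_sketch` is a hypothetical name, not the ATen code:

```python
import numpy as np

def vecdot_sketch(x, y, dim=-1):
    # Mirror of the C++ logic: 1-D inputs take the vdot fast path,
    # which also conjugates its first argument; batched inputs
    # reduce conj(x) * y over `dim`.
    if x.ndim == 1 and y.ndim == 1:
        return np.vdot(x, y)
    return (np.conj(x) * y).sum(axis=dim)

a = np.array([[1 + 1j, 2j], [3.0 + 0j, 1 - 1j]])
b = np.array([[1.0 + 0j, 1j], [2j, 1.0 + 0j]])
batched = vecdot_sketch(a, b)        # shape (2,), reduced over the last dim
single = vecdot_sketch(a[0], b[0])   # vdot fast path
```

Both code paths should agree on the first row, which is what makes the fast path safe.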

torch/_torch_docs.py

Lines changed: 18 additions & 9 deletions

@@ -3412,10 +3412,6 @@ def merge_dicts(*dicts):
 Unlike NumPy's dot, torch.dot intentionally only supports computing the dot product
 of two 1D tensors with the same number of elements.

-.. seealso::
-
-        :func:`torch.linalg.vecdot` the vector product of two batches of vectors along a dimension.
-
 Args:
     input (Tensor): first tensor in the dot product, must be 1D.
     other (Tensor): second tensor in the dot product, must be 1D.

@@ -3433,21 +3429,34 @@ def merge_dicts(*dicts):
 r"""
 vdot(input, other, *, out=None) -> Tensor

-Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
-differently than dot(a, b). If the first argument is complex, the complex conjugate of the
-first argument is used for the calculation of the dot product.
+Computes the dot product of two 1D vectors along a dimension.
+
+In symbols, this function computes
+
+.. math::
+
+    \sum_{i=1}^n \overline{x_i}y_i.
+
+where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.

 .. note::

     Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
     of two 1D tensors with the same number of elements.

+.. seealso::
+
+        :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
+
 Args:
     input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
     other (Tensor): second tensor in the dot product, must be 1D.

 Keyword args:
-{out}
+""" + fr"""
+.. note:: {common_args["out"]}
+""" + r"""

 Example::

@@ -3459,7 +3468,7 @@ def merge_dicts(*dicts):
     tensor([16.+1.j])
     >>> torch.vdot(b, a)
     tensor([16.-1.j])
-""".format(**common_args))
+""")

 add_docstr(torch.eig,
 r"""
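The docstring formula \sum_i \overline{x_i} y_i can be checked with plain Python complex numbers (a hand-rolled check, not torch code):

```python
# Per the new docstring, vdot([a, b], [c, d]) = conj(a)*c + conj(b)*d.
x = [1 + 2j, 3 - 1j]
y = [2 - 1j, 1j]
s = sum(xi.conjugate() * yi for xi, yi in zip(x, y))
```

Note the asymmetry this creates: swapping the arguments conjugates the result, which is why the docstring example shows `vdot(a, b)` and `vdot(b, a)` differing only in the sign of the imaginary part.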

torch/linalg/__init__.py

Lines changed: 6 additions & 10 deletions

@@ -2255,30 +2255,26 @@
 vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
 linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor

-Computes the real dot product of two batches of vectors along a dimension.
+Computes the dot product of two batches of vectors along a dimension.

 In symbols, this function computes

 .. math::

-    \sum_{i=1}^n x_iy_i.
+    \sum_{i=1}^n \overline{x_i}y_i.

-over the dimension :attr:`dim`.
+over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
+vectors, and it is the identity for real vectors.

 Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
 It also supports broadcasting.

-.. seealso::
-
-        :func:`torch.matmul` computes a general matrix-matrix multiplication for batches
-        of matrices.
-
 Args:
     x (Tensor): first batch of vectors.
     y (Tensor): second batch of vectors.

 Keyword args:
-    dim (int): Dimension along which to compute the real dot product. Default: `-1`.
+    dim (int): Dimension along which to compute the dot product. Default: `-1`.
     out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

 Examples::

@@ -2287,6 +2283,6 @@
 >>> v2 = torch.randn(3, 2)
 >>> linalg.vecdot(v1, v2)
 tensor([ 0.3223,  0.2815, -0.1944])
->>> torch.dot(v1[0], v2[0])
+>>> torch.vdot(v1[0], v2[0])
 tensor(0.3223)
 """)
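The docstring also states that `linalg.vecdot` supports broadcasting; a NumPy sketch of that behavior (the helper name is hypothetical, and NumPy stands in for torch here):

```python
import numpy as np

def vecdot(x, y, dim=-1):
    # conj-multiply-sum, as in the docstring formula
    return (np.conj(x) * y).sum(axis=dim)

v1 = np.arange(6.0).reshape(3, 2)   # a batch of three 2-vectors
v2 = np.array([1.0, 10.0])          # shape (2,) broadcasts against (3, 2)
out = vecdot(v1, v2)                # one dot product per batch row, shape (3,)
```

Each batch row is dotted against the same `v2`, so a single vector can be matched against a whole batch without materializing copies.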

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion

@@ -9789,7 +9789,7 @@ def ref_pairwise_distance(input1, input2):
     ),
 OpInfo('linalg.vecdot',
        aten_name='linalg_vecdot',
-       ref=lambda x, y, *, dim=-1: (x * y).sum(dim),
+       ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
        dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
        sample_inputs_func=sample_inputs_linalg_vecdot,
        supports_forward_ad=True,
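A quick sanity check that the updated `ref` lambda agrees with NumPy's conjugating `vdot` for 1-D complex inputs (NumPy stands in for torch in this sketch):

```python
import numpy as np

# Same shape as the updated OpInfo reference, written against NumPy:
ref = lambda x, y, *, dim=-1: (np.conj(x) * y).sum(axis=dim)

x = np.array([1j, 2.0 + 0j])
y = np.array([1.0 + 0j, 1j])
agrees = ref(x, y) == np.vdot(x, y)
```

The old reference `(x * y).sum(dim)` would disagree with `vdot` on these inputs, which is exactly what this one-line fix addresses.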
