[Array API] Add linalg.vecdot by lezcano · Pull Request #70542 · pytorch/pytorch

Status: Closed. Wants to merge 14 commits.

33 changes: 33 additions & 0 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -4102,6 +4102,39 @@ TORCH_IMPL_FUNC(linalg_ldl_solve_out)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
  checkFloatingOrComplex(x, "linalg.vecdot");
  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
  // out checks
  TORCH_CHECK(out.scalar_type() == x.scalar_type(),
              "linalg.vecdot: Expected out of dtype ", x.scalar_type(),
              " but found ", out.scalar_type());
  checkSameDevice("linalg.vecdot", x, out);

  // Computes x^H y
  if (x.dim() == 1 && y.dim() == 1) {
    at::native::resize_output(out, {});
    return at::vdot_out(out, x, y);
  } else {
    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
  }
}

Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
  checkFloatingOrComplex(x, "linalg.vecdot");
  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
  // Computes x^H y
  if (x.dim() == 1 && y.dim() == 1) {
    return at::vdot(x, y);
  } else {
    return x.conj().mul(y).sum(/*dim=*/dim);
  }
}
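
A minimal Python sketch of the semantics implemented by the two functions above, for readers skimming the diff. It is illustrative only (the PR ships the C++ paths), and vecdot_reference is a hypothetical name, not part of the change:

    import torch

    def vecdot_reference(x, y, dim=-1):
        # Mirrors the branch structure above: 1-D inputs dispatch to vdot,
        # which already conjugates its first argument (x^H y).
        if x.dim() == 1 and y.dim() == 1:
            return torch.vdot(x, y)
        # Batched case: conjugate x, multiply elementwise, reduce along `dim`.
        return (x.conj() * y).sum(dim=dim)

    a = torch.randn(3, dtype=torch.cfloat)
    b = torch.randn(3, dtype=torch.cfloat)
    print(torch.allclose(vecdot_reference(a, b), torch.vdot(a, b)))  # expected: True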

/*
Solves the matrix equation AX = B for A triangular.
'left' If true solves AX = B, if false solves XA = B
7 changes: 7 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -11612,6 +11612,13 @@
- func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
  python_module: linalg
  variants: function

- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
  python_module: linalg

- func: linalg_matrix_exp(Tensor self) -> Tensor
  python_module: linalg
  variants: function
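
The two new schema entries above generate the functional and out= variants exposed under torch.linalg. A quick usage sketch under that assumption (shapes and values are illustrative):

    import torch

    x = torch.randn(2, 4)
    y = torch.randn(2, 4)

    z = torch.linalg.vecdot(x, y, dim=-1)       # functional variant: linalg_vecdot
    out = torch.empty(2)
    torch.linalg.vecdot(x, y, dim=-1, out=out)  # out= variant: linalg_vecdot.out
    print(torch.allclose(z, out))               # expected: True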
1 change: 1 addition & 0 deletions docs/source/linalg.rst
@@ -84,6 +84,7 @@ Matrix Products

cross
matmul
vecdot
multi_dot
householder_product

1 change: 1 addition & 0 deletions test/allowlist_for_publicAPI.json
@@ -1148,6 +1148,7 @@
"tensorinv",
"tensorsolve",
"vander",
"vecdot",
"vector_norm"
],
"torch.multiprocessing": [
1 change: 1 addition & 0 deletions test/test_meta.py
@@ -497,6 +497,7 @@ def run_meta_crossref(
    torch.nanmean: {bf16, f16, f32, f64},  # TODO(chilli): Doesn't seem to work for some reason?
    torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
    torch.linalg.pinv: {f32, f64},
    torch.linalg.vecdot: {f16, bf16, f32, f64},  # aten::prod
    torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
}

23 changes: 18 additions & 5 deletions torch/_torch_docs.py
@@ -3513,21 +3513,34 @@ def merge_dicts(*dicts):
r"""
vdot(input, other, *, out=None) -> Tensor

Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
differently than dot(a, b). If the first argument is complex, the complex conjugate of the
first argument is used for the calculation of the dot product.
Computes the dot product of two 1D tensors.

In symbols, this function computes

.. math::

    \sum_{i=1}^n \overline{x_i}y_i,

where :math:`\overline{x_i}` denotes the conjugate for complex
vectors, and it is the identity for real vectors.

.. note::

Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
of two 1D tensors with the same number of elements.

.. seealso::

:func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.

Args:
    input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
    other (Tensor): second tensor in the dot product, must be 1D.

Keyword args:
    {out}
""" + fr"""
.. note:: {common_args["out"]}
""" + r"""

Example::

@@ -3539,7 +3552,7 @@ def merge_dicts(*dicts):
tensor([16.+1.j])
>>> torch.vdot(b, a)
tensor([16.-1.j])
""".format(**common_args))
""")

add_docstr(torch.eig,
r"""
35 changes: 35 additions & 0 deletions torch/linalg/__init__.py
@@ -2801,3 +2801,38 @@
[ 1, 3, 9],
[ 1, 5, 25]])
""")

vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor

Computes the dot product of two batches of vectors along a dimension.

In symbols, this function computes

.. math::

    \sum_{i=1}^n \overline{x_i}y_i

over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
vectors, and it is the identity for real vectors.

Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.

    Review thread on the line above:

    Collaborator: If it supports everything do we need to mention it?

    Reply: I think it's fine to be explicit here. Also, many other linalg functions have a note on supported dtypes.

    Collaborator (author): Leaving it as "better explicit than implicit". Also, it does support bool, but it's not announced here. If I had to choose I'd disallow integral and boolean types, but left them because why not.

It also supports broadcasting.

Args:
    x (Tensor): first batch of vectors of shape `(*, n)`.
    y (Tensor): second batch of vectors of shape `(*, n)`.

Keyword args:
    dim (int): Dimension along which to compute the dot product. Default: `-1`.
    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.

Examples::

    >>> v1 = torch.randn(3, 2)
    >>> v2 = torch.randn(3, 2)
    >>> linalg.vecdot(v1, v2)
    tensor([ 0.3223, 0.2815, -0.1944])
    >>> torch.vdot(v1[0], v2[0])
    tensor(0.3223)
""")
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -923,6 +923,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.ravel: lambda input: -1,
        torch.real: lambda input, out=None: -1,
        torch.vdot: lambda input, other, out=None: -1,
        torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
        torch.view_as_real: lambda input: -1,
        torch.view_as_complex: lambda input: -1,
        torch.reciprocal: lambda input, out=None: -1,
23 changes: 22 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -1351,7 +1351,7 @@ class ReductionOpInfo(OpInfo):
the optional keyword parameters of the ReductionOpInfo constructor.

If a reduction operator does not yet implement the full required API of
reduction operators, this should be documented by skipping the failing
reduction operators, this should be documented by xfailing the failing
tests rather than adding optional parameters to ReductionOpInfo.

NOTE
@@ -3045,6 +3045,17 @@ def error_inputs_isclose(op, device, **kwargs):
error_regex='atol must be greater than or equal to zero')


def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    batches = ((), (0,), (1,), (5,))
    ns = (0, 1, 3, 5)
    for b, n in product(batches, ns):
        shape = b + (n,)
        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
        for i in range(len(shape)):
            yield SampleInput(make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i))


def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    return (SampleInput(make_arg((1, 2))),
@@ -12464,6 +12475,16 @@ def error_inputs_mean(op_info, device, **kwargs):
           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
           ),
    OpInfo('linalg.vecdot',
           aten_name='linalg_vecdot',
           ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
           dtypes=floating_and_complex_types_and(torch.bfloat16),
           dtypesIfCUDA=floating_and_complex_types_and(torch.half,
                                                       *[torch.bfloat16] if (CUDA11OrLater or TEST_WITH_ROCM) else []),
           sample_inputs_func=sample_inputs_linalg_vecdot,
           check_batched_forward_grad=False,
           supports_forward_ad=True,
           supports_fwgrad_bwgrad=True),
    OpInfo('linalg.cond',
           aten_name='linalg_cond',
           dtypes=floating_and_complex_types(),