Revert "[Array API] Add linalg.vecdot (#70542)" · pytorch/pytorch@39f659c · GitHub

Commit 39f659c

Revert "[Array API] Add linalg.vecdot (#70542)"
This reverts commit 74208a9. Reverted #70542 on behalf of https://github.com/malfet due to Broke CUDA-10.2 for vecdot_bfloat16, see https://hud.pytorch.org/pytorch/pytorch/commit/74208a9c68b5892b9dde39d06350fe7b92691429
1 parent 80bf2ea commit 39f659c

9 files changed: +6 -118 lines changed

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 0 additions & 33 deletions

@@ -4102,39 +4102,6 @@ TORCH_IMPL_FUNC(linalg_ldl_solve_out)
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ solve_triangular ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-Tensor& linalg_vecdot_out(const Tensor& x, const Tensor& y, int64_t dim, Tensor& out) {
-  checkFloatingOrComplex(x, "linalg.vecdot");
-  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
-              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
-              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
-  // out checks
-  TORCH_CHECK(out.scalar_type() == x.scalar_type(),
-              "linalg.vecdot: Expected out of dtype", x.scalar_type(),
-              " but found ", out.scalar_type());
-  checkSameDevice("linalg.vecdot", x, out);
-
-  // Computes x^H y
-  if (x.dim() == 1 && y.dim() == 1) {
-    at::native::resize_output(out, {});
-    return at::vdot_out(out, x, y);
-  } else {
-    return at::sum_out(out, x.conj() * y, /*dim=*/dim);
-  }
-}
-
-Tensor linalg_vecdot(const Tensor& x, const Tensor& y, int64_t dim) {
-  checkFloatingOrComplex(x, "linalg.vecdot");
-  TORCH_CHECK(x.scalar_type() == y.scalar_type(),
-              "linalg.vecdot: Expected x and y to have the same dtype, but found x of type ",
-              x.scalar_type(), " and y of type ", y.scalar_type(), " instead");
-  // Computes x^H y
-  if (x.dim() == 1 && y.dim() == 1) {
-    return at::vdot(x, y);
-  } else {
-    return x.conj().mul(y).sum(/*dim=*/dim);
-  }
-}
-
 /*
 Solves the matrix equation AX = B for A triangular.
 'left' If true solves AX = B, if false solves XA = B
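For orientation, the removed kernel computes the conjugated dot product x^H y, taking a fast path through at::vdot for plain 1D inputs because vdot already conjugates its first argument. A minimal Python sketch of the same logic that still works after the revert (the name vecdot_reference is hypothetical and only mirrors the removed C++):

import torch

def vecdot_reference(x, y, dim=-1):
    # 1D fast path: torch.vdot conjugates its first argument, like the removed kernel
    if x.dim() == 1 and y.dim() == 1:
        return torch.vdot(x, y)
    # General case: reduce conj(x) * y along the requested dimension
    return (x.conj() * y).sum(dim=dim)

x = torch.randn(5, 3, dtype=torch.cfloat)
y = torch.randn(5, 3, dtype=torch.cfloat)
out = vecdot_reference(x, y)  # shape (5,): one dot product per row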

aten/src/ATen/native/native_functions.yaml

Lines changed: 0 additions & 7 deletions

@@ -11612,13 +11612,6 @@
 - func: linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
   python_module: linalg
 
-- func: linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor
-  python_module: linalg
-  variants: function
-
-- func: linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)
-  python_module: linalg
-
 - func: linalg_matrix_exp(Tensor self) -> Tensor
   python_module: linalg
   variants: function

docs/source/linalg.rst

Lines changed: 0 additions & 1 deletion

@@ -84,7 +84,6 @@ Matrix Products
 
     cross
     matmul
-    vecdot
    multi_dot
    householder_product

test/allowlist_for_publicAPI.json

Lines changed: 0 additions & 1 deletion

@@ -1208,7 +1208,6 @@
     "tensorinv",
     "tensorsolve",
     "vander",
-    "vecdot",
     "vector_norm"
   ],
   "torch.multiprocessing": [

test/test_meta.py

Lines changed: 0 additions & 1 deletion

@@ -497,7 +497,6 @@ def run_meta_crossref(
     torch.nanmean: {bf16, f16, f32, f64},  # TODO(chilli): Doesn't seem to work for some reason?
     torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
     torch.linalg.pinv: {f32, f64},
-    torch.linalg.vecdot: {f16, bf16, f32, f64},  # aten::prod
     torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
 }

torch/_torch_docs.py

Lines changed: 5 additions & 18 deletions

@@ -3510,34 +3510,21 @@ def merge_dicts(*dicts):
            r"""
 vdot(input, other, *, out=None) -> Tensor
 
-Computes the dot product of two 1D vectors along a dimension.
-
-In symbols, this function computes
-
-.. math::
-
-    \sum_{i=1}^n \overline{x_i}y_i.
-
-where :math:`\overline{x_i}` denotes the conjugate for complex
-vectors, and it is the identity for real vectors.
+Computes the dot product of two 1D tensors. The vdot(a, b) function handles complex numbers
+differently than dot(a, b). If the first argument is complex, the complex conjugate of the
+first argument is used for the calculation of the dot product.
 
 .. note::
 
     Unlike NumPy's vdot, torch.vdot intentionally only supports computing the dot product
     of two 1D tensors with the same number of elements.
 
-.. seealso::
-
-        :func:`torch.linalg.vecdot` computes the dot product of two batches of vectors along a dimension.
-
 Args:
     input (Tensor): first tensor in the dot product, must be 1D. Its conjugate is used if it's complex.
     other (Tensor): second tensor in the dot product, must be 1D.
 
 Keyword args:
-""" + fr"""
-.. note:: {common_args["out"]}
-""" + r"""
+    {out}
 
 Example::

@@ -3549,7 +3536,7 @@ def merge_dicts(*dicts):
     tensor([16.+1.j])
     >>> torch.vdot(b, a)
     tensor([16.-1.j])
-""")
+""".format(**common_args))
 
 add_docstr(torch.eig,
            r"""

torch/linalg/__init__.py

Lines changed: 0 additions & 35 deletions

@@ -2801,38 +2801,3 @@
         [ 1,  3,  9],
         [ 1,  5, 25]])
 """)
-
-vecdot = _add_docstr(_linalg.linalg_vecdot, r"""
-linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor
-
-Computes the dot product of two batches of vectors along a dimension.
-
-In symbols, this function computes
-
-.. math::
-
-    \sum_{i=1}^n \overline{x_i}y_i.
-
-over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex
-vectors, and it is the identity for real vectors.
-
-Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes.
-It also supports broadcasting.
-
-Args:
-    x (Tensor): first batch of vectors of shape `(*, n)`.
-    y (Tensor): second batch of vectors of shape `(*, n)`.
-
-Keyword args:
-    dim (int): Dimension along which to compute the dot product. Default: `-1`.
-    out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.
-
-Examples::
-
-    >>> v1 = torch.randn(3, 2)
-    >>> v2 = torch.randn(3, 2)
-    >>> linalg.vecdot(v1, v2)
-    tensor([ 0.3223,  0.2815, -0.1944])
-    >>> torch.vdot(v1[0], v2[0])
-    tensor(0.3223)
-""")

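The removed docstring also notes that linalg.vecdot supported broadcasting. A small sketch of what that meant in practice, expressed with ops that remain after this revert (the shapes are illustrative only):

import torch

# A (3, 1, 4) batch against a (5, 4) batch broadcasts to (3, 5, 4); reducing the last
# dimension yields one conjugated dot product per broadcast pair, i.e. shape (3, 5).
x = torch.randn(3, 1, 4, dtype=torch.cfloat)
y = torch.randn(5, 4, dtype=torch.cfloat)
result = (x.conj() * y).sum(dim=-1)
print(result.shape)  # torch.Size([3, 5])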
torch/overrides.py

Lines changed: 0 additions & 1 deletion

@@ -923,7 +923,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.ravel: lambda input: -1,
     torch.real: lambda input, out=None: -1,
     torch.vdot: lambda input, other, out=None: -1,
-    torch.linalg.vecdot: lambda input, other, dim=-1, out=None: -1,
     torch.view_as_real: lambda input: -1,
     torch.view_as_complex: lambda input: -1,
     torch.reciprocal: lambda input, out=None: -1,

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 21 deletions

@@ -1350,7 +1350,7 @@ class ReductionOpInfo(OpInfo):
     the optional keyword parameters of the ReductionOpInfo constructor.
 
     If a reduction operator does not yet implement the full required API of
-    reduction operators, this should be documented by xfailing the failing
+    reduction operators, this should be documented by skipping the failing
     tests rather than adding optional parameters to ReductionOpInfo.
 
     NOTE

@@ -3044,17 +3044,6 @@ def error_inputs_isclose(op, device, **kwargs):
                        error_regex='atol must be greater than or equal to zero')
 
 
-def sample_inputs_linalg_vecdot(op_info, device, dtype, requires_grad, **kwargs):
-    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
-    batches = ((), (0,), (1,), (5,))
-    ns = (0, 1, 3, 5)
-    for b, n in product(batches, ns):
-        shape = b + (n,)
-        yield SampleInput(make_arg(shape), args=(make_arg(shape),))
-        for i in range(len(shape)):
-            yield SampleInput(make_arg(shape), args=(make_arg(shape),), kwargs=dict(dim=i))
-
-
 def sample_inputs_t(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
     return (SampleInput(make_arg((1, 2))),

@@ -12443,15 +12432,6 @@ def error_inputs_mean(op_info, device, **kwargs):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
            ),
-    OpInfo('linalg.vecdot',
-           aten_name='linalg_vecdot',
-           ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
-           dtypes=floating_and_complex_types_and(torch.bfloat16),
-           dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
-           sample_inputs_func=sample_inputs_linalg_vecdot,
-           check_batched_forward_grad=False,
-           supports_forward_ad=True,
-           supports_fwgrad_bwgrad=True),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
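As a rough illustration of the coverage the deleted OpInfo provided, the sketch below walks the same batch/length grid as the removed sample_inputs_linalg_vecdot and checks the OpInfo reference formula against torch.vdot on the 1D cases; the helper vecdot_ref is hypothetical and uses only ops that still exist after the revert:

import torch
from itertools import product

def vecdot_ref(x, y, dim=-1):
    # Reference formula from the removed OpInfo entry: (x.conj() * y).sum(dim)
    return (x.conj() * y).sum(dim)

batches = ((), (0,), (1,), (5,))   # batch shapes enumerated by the removed sample generator
ns = (0, 1, 3, 5)                  # vector lengths enumerated by the removed sample generator
for b, n in product(batches, ns):
    shape = b + (n,)
    x = torch.randn(shape, dtype=torch.cdouble)
    y = torch.randn(shape, dtype=torch.cdouble)
    out = vecdot_ref(x, y)
    assert out.shape == b            # the reduced dimension disappears
    if x.dim() == 1:
        # torch.vdot conjugates its first argument, matching the reference on 1D inputs
        torch.testing.assert_close(out, torch.vdot(x, y))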

0 commit comments
