8000 Update on "[Array API] Add linalg.vecdot" · pytorch/pytorch@c26e36c · GitHub
[go: up one dir, main page]

Skip to content

Commit c26e36c

Browse files
committed
Update on "[Array API] Add linalg.vecdot"
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot) For the complex case, it chooses to implement \sum x_i y_i. See the discussion in data-apis/array-api#356 Edit. When it comes to testing, this function is not quite a binopt, nor a reduction opt. As such, we're this close to be able to get the extra testing, but we don't quite make it. Now, it's such a simple op that I think we'll make it without this. Resolves #18027. cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi [ghstack-poisoned]
2 parents dd97da7 + abb921a commit c26e36c

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch/testing/_internal/common_methods_invocations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12416,7 +12416,7 @@ def error_inputs_mean(op_info, device, **kwargs):
1241612416
aten_name='linalg_vecdot',
1241712417
ref=lambda x, y, *, dim=-1: (x.conj() * y).sum(dim),
1241812418
dtypes=floating_and_complex_types_and(torch.bfloat16),
12419-
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, torch.bfloat16),
12419+
dtypesIfCUDA=floating_and_complex_types_and(torch.float16, *[torch.bfloat16] if CUDA11OrLater else []),
1242012420
sample_inputs_func=sample_inputs_linalg_vecdot,
1242112421
check_batched_forward_grad=False,
1242212422
supports_forward_ad=True,

0 commit comments

Comments
 (0)
0