Update on "[Array API] Add linalg.vecdot" · pytorch/pytorch@c926457 · GitHub
Commit c926457

Update on "[Array API] Add linalg.vecdot"
This PR adds the function `linalg.vecdot` specified by the [Array API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot). For the complex case, it chooses to implement \sum x_i y_i. See the discussion in data-apis/array-api#356.

Edit: when it comes to testing, this function is neither quite a binary op nor quite a reduction op. As such, we are very close to being able to reuse their extra testing, but we do not quite make it. It is such a simple op, though, that we should manage without it.

cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

[ghstack-poisoned]
1 parent 5cbd57d commit c926457
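The reference implementation registered in the OpInfo below is `(x * y).sum(dim)`, i.e. \sum x_i y_i with no complex conjugation, matching the choice described in the commit message. A minimal sketch of those semantics, using NumPy in place of torch and illustrative sample values:

```python
import numpy as np

def vecdot_ref(x, y, dim=-1):
    # Elementwise product followed by a sum over the reduction dimension.
    # Note: no conjugation of either argument, i.e. \sum x_i y_i.
    return (x * y).sum(axis=dim)

x = np.array([1 + 2j, 3 + 0j])
y = np.array([2 - 1j, 1 + 1j])
# (1+2j)*(2-1j) + 3*(1+1j) = (4+3j) + (3+3j) = 7+6j
print(vecdot_ref(x, y))
```

For batched inputs, `dim` selects which axis is contracted while the remaining axes broadcast, which is why the OpInfo needs its own `sample_inputs_func` rather than the generic binary-op samples.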

File tree

1 file changed

+12
-13
lines changed


torch/testing/_internal/common_methods_invocations.py

Lines changed: 12 additions & 13 deletions
@@ -968,7 +968,7 @@ def sample_inputs_reduction(op_info, device, dtype, requires_grad, **kwargs):
     supports_multiple_dims: bool = kwargs.get('supports_multiple_dims', True)

     # TODO(@heitorschueroff) Once all reduction operators are using ReductionOpInfo
-    # use op_info.genearte_args_kwargs directly.
+    # use op_info.generate_args_kwargs directly.
     generate_args_kwargs = kwargs.get('generate_args_kwargs', lambda *args, **kwargs: (yield tuple(), {}))

     inputs: List[SampleInput] = []
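The default for `generate_args_kwargs` in the hunk above is a generator lambda: a lambda whose body is a parenthesized `yield`, so calling it returns a generator that produces a single empty `(args, kwargs)` pair. A small standalone sketch of that idiom:

```python
# Generator lambda: the parenthesized yield makes the lambda a generator
# function, so each call yields exactly one (args, kwargs) pair.
default_gen = lambda *args, **kwargs: (yield tuple(), {})

pairs = list(default_gen())
print(pairs)  # [((), {})]
```

This means reductions that need no extra arguments get one plain invocation per sample input by default.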
@@ -1101,7 +1101,7 @@ class ReductionOpInfo(OpInfo):
     the optional keyword parameters of the ReductionOpInfo constructor.

     If a reduction operator does not yet implement the full required API of
-    reduction operators, this should be documented by skipping the failing
+    reduction operators, this should be documented by xfailing the failing
     tests rather than adding optional parameters to ReductionOpInfo.

     NOTE
@@ -9787,17 +9787,16 @@ def ref_pairwise_distance(input1, input2):
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
            ),
-    BinaryUfuncInfo('linalg.vecdot',
-                    aten_name='linalg_vecdot',
-                    ref=lambda x, y, *, dim=-1: (x * y).sum(dim),
-                    dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
-                    sample_inputs_func=sample_inputs_linalg_vecdot,
-                    supports_forward_ad=True,
-                    supports_fwgrad_bwgrad=True,
-                    skips=(
-                        # torch.sum(out=) has an incorrect behaviour
-                        DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),
-                    ),),
+    OpInfo('linalg.vecdot',
+           aten_name='linalg_vecdot',
+           ref=lambda x, y, *, dim=-1: (x * y).sum(dim),
+           dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
+           sample_inputs_func=sample_inputs_linalg_vecdot,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           skips=(
+               # FIXME torch.sum(out=) has an incorrect behaviour
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out'),),),
     OpInfo('linalg.cond',
            aten_name='linalg_cond',
            dtypes=floating_and_complex_types(),
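The `skips` entry above wraps `unittest.expectedFailure`, so the known-bad `test_out` case still runs but its failure is recorded as expected (an xfail) rather than an error. A minimal standalone sketch of that mechanism, with a deliberately failing stand-in test:

```python
import io
import unittest

class TestOutSketch(unittest.TestCase):
    # Stand-in for the xfailed TestCommon.test_out: the decorator marks the
    # failure as expected, so the run still counts as successful.
    @unittest.expectedFailure
    def test_out(self):
        self.assertEqual(1, 2)  # deliberately fails

suite = unittest.TestLoader().loadTestsFromTestCase(TestOutSketch)
result = unittest.TextTestRunner(stream=io.StringIO(), verbosity=0).run(suite)
print(result.wasSuccessful())         # True
print(len(result.expectedFailures))   # 1
```

Using xfail instead of a plain skip means the test suite will flag the case (as an unexpected success) the moment the underlying `torch.sum(out=)` behaviour is fixed.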

0 commit comments
