Support Batch Dot Product · Issue #18027 · pytorch/pytorch · GitHub

Support Batch Dot Product #18027

Closed
sidazhang opened this issue Mar 14, 2019 · 16 comments
Labels
feature A request for a proper, new feature. triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@sidazhang

🚀 Feature

Support batch dot product

Motivation

Commonly used operation

Alternatives

Currently, we can do this with bmm, but I think it makes sense to support a batch dot product directly since it is so commonly used.

@sidazhang
Author

@suo I can implement this if you can point me to where this code should go

@soumith
Member
soumith commented Mar 14, 2019

isn't that just a matrix multiply?

@ssnl
Collaborator
ssnl commented Mar 14, 2019

@soumith it's the diagonal of an mm
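To make the "diagonal of a matrix multiply" remark concrete, here is a small sketch (sizes and variable names are illustrative). It computes the batched dot product row by row with `torch.dot`, then checks that it matches the diagonal of `x @ y.T`. Note that the matrix-multiply route computes all B x B pairwise dot products and keeps only B of them.

```python
import torch

torch.manual_seed(0)
B, S = 4, 3
x = torch.randn(B, S)
y = torch.randn(B, S)

# Reference definition: out[i] = dot(x[i], y[i]) for each row i
ref = torch.stack([torch.dot(x[i], y[i]) for i in range(B)])

# Same values as the diagonal of the full (B, B) matrix product
via_mm = torch.diagonal(x @ y.T)

print(torch.allclose(ref, via_mm, atol=1e-6))  # True
```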

@sidazhang
Author

Yeah, it is just a matrix multiply, but it is annoying to write every time. It is so commonly used that I think we should just have a batch dot method.

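import torch

def bdot(a, b):
    # a, b: (B, S); returns (B,) with out[i] = dot(a[i], b[i])
    B, S = a.shape
    # (B, 1, S) @ (B, S, 1) -> (B, 1, 1), flattened to (B,)
    return torch.bmm(a.view(B, 1, S), b.view(B, S, 1)).reshape(-1)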

@vishwakftw vishwakftw added the feature A request for a proper, new feature. label Mar 15, 2019
@zou3519
Contributor
zou3519 commented Mar 15, 2019

@sidazhang what are the sizes of a and b? (B, S) and (B, S), and the output is of size (B,)?

If so, I think einsum might be what you're looking for:

import torch
B = 3
S = 5

x = torch.randn(B, S)
y = torch.randn(B, S)
out = torch.einsum('bs,bs->b', x, y)
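If the inputs have more than one batch dimension, the same einsum idea extends with an ellipsis. A minimal sketch, assuming both inputs have identical shapes (..., S) and the dot product is taken over the last dimension:

```python
import torch

x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)

# '...' stands for any number of leading (batch) dimensions;
# 's' is summed over because it does not appear in the output.
out = torch.einsum('...s,...s->...', x, y)
print(out.shape)  # torch.Size([2, 3])

# Same result as an elementwise multiply followed by a sum
assert torch.allclose(out, (x * y).sum(-1), atol=1e-6)
```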

@ssnl
Collaborator
ssnl commented Mar 15, 2019

So, afaik, neither bmm nor einsum is efficient for this case. But maybe that has changed since I heard about it.

@ssnl
Collaborator
ssnl commented Mar 15, 2019

@sidazhang what I would suggest is to benchmark the following for your use case:

  1. bmm
  2. einsum
  3. mul+sum

and just write a helper function using the best one. It's simple enough that I'm not sure we should provide it in core. NumPy doesn't have a bdot method either.
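One way to run the comparison of the three candidates listed above is a small harness like the following. It is a sketch: the helper name `bench`, the tensor sizes, and the repeat counts are arbitrary choices, not anything from PyTorch. Because CUDA kernels launch asynchronously, it synchronizes before starting and before stopping the clock whenever the inputs live on a GPU.

```python
import time

import torch


def bench(fn, *args, repeats=100, warmup=10):
    """Mean wall-clock seconds per call of fn(*args) (hypothetical helper)."""
    on_cuda = any(a.is_cuda for a in args)
    for _ in range(warmup):  # absorb lazy initialization and cache effects
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if on_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before reading the clock
    return (time.perf_counter() - start) / repeats


def via_bmm(x, y):
    B, S = x.shape
    return torch.bmm(x.view(B, 1, S), y.view(B, S, 1)).reshape(-1)


def via_einsum(x, y):
    return torch.einsum('bs,bs->b', x, y)


def via_mul_sum(x, y):
    return (x * y).sum(-1)


candidates = {"bmm": via_bmm, "einsum": via_einsum, "mul+sum": via_mul_sum}

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 128, device=device)
y = torch.randn(1024, 128, device=device)

for name, fn in candidates.items():
    print(f"{name:8s} {bench(fn, x, y) * 1e6:8.1f} us")
```

Whichever candidate wins on your hardware can then be wrapped in a one-line helper.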

@samuelbroscheit
samuelbroscheit commented Mar 15, 2019
import torch
B = 1024
S = 128

x = torch.randn(B, S)
y = torch.randn(B, S)

%%timeit
out = torch.einsum('bs,bs->b', x, y)
57.2 µs ± 4.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
out = torch.bmm(x.view(B, 1, S), y.view(B, S, 1)).reshape(-1)
41.8 µs ± 1.41 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%%timeit
out = (x*y).sum(-1)
30.4 µs ± 440 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

import torch
B = 1024
S = 128

x = torch.randn(B, S).cuda()
y = torch.randn(B, S).cuda()

%%timeit
out = torch.einsum('bs,bs->b', x, y)
torch.cuda.synchronize()
240 µs ± 4.11 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
out = torch.bmm(x.view(B, 1, S), y.view(B, S, 1)).reshape(-1)
torch.cuda.synchronize()
225 µs ± 337 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
out = (x*y).sum(-1)
torch.cuda.synchronize()
30.7 µs ± 395 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


torch.__version__
'1.0.1.post2'

Maybe a dedicated bdot, even if trivial, would be nice, so that users automatically get the most efficient option?
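A dedicated helper along these lines could look like the sketch below. `bdot` here is a hypothetical name, not a PyTorch function. It uses mul + sum because that was the fastest option in the 2019 timings above (later timings further down the thread differ, so it is worth benchmarking on your own setup), and it checks shapes explicitly instead of silently broadcasting.

```python
import torch


def bdot(a: torch.Tensor, b: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Batched dot product: sum of a * b along `dim` (hypothetical helper).

    For a, b of shape (B, S) and dim=-1 the result has shape (B,),
    with out[i] == torch.dot(a[i], b[i]).
    """
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    return (a * b).sum(dim)


x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
print(bdot(x, y))  # tensor([4., 5.])
```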

@zou3519
Contributor
zou3519 commented Mar 15, 2019

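You need to add a synchronize call when using timeit with CUDA functions, because CUDA kernels run asynchronously and timeit would otherwise mostly measure kernel launch overhead. Something like: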

def fn():
    out = torch.einsum('bs,bs->b', x, y)
    torch.cuda.synchronize()

%timeit fn()
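PyTorch releases newer than the one used in this thread ship `torch.utils.benchmark.Timer`, which performs warm-up and synchronizes CUDA work itself, so the manual `torch.cuda.synchronize()` is not needed there. A sketch assuming such a release is installed; the sizes mirror the timings above:

```python
import torch
from torch.utils import benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 128, device=device)
y = torch.randn(1024, 128, device=device)

for stmt in ["torch.einsum('bs,bs->b', x, y)", "(x * y).sum(-1)"]:
    timer = benchmark.Timer(stmt=stmt, globals={"torch": torch, "x": x, "y": y})
    m = timer.timeit(100)  # returns a Measurement; CUDA synchronization is handled by Timer
    print(f"{stmt:35s} {m.mean * 1e6:8.1f} us")
```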

@samuelbroscheit

Thanks, I updated the timings.

@soumith
Member
soumith commented Mar 29, 2019

it looks like the mul + sum is always cheaper, and sometimes significantly so

@sidazhang
Author

it looks like the mul + sum is always cheaper, and sometimes significantly so

Hence, how do you feel about providing a wrapper for users?

@fmassa
Member
fmassa commented Apr 1, 2019

I'm not sure we should be making torch.dot support batched tensors.

If we add it, it should preferably follow numpy semantics of np.dot.

But np.dot is tagged for deprecation (numpy/numpy#5859), and NumPy developers have stated that they regret the current semantics of np.dot:

matmul's broadcasting is much more general, and in my opinion, also easier to understand. For example, it can do batch matrix-multiplication, but also can still do outer product style broadcasting if you insert dummy dimensions of length 1 (the axes do end up in a different order), e.g.,
batch matmul: [p x q x r] matmul [p x r x t] -> [p x q x t]
outer product matmul: [p x 1 x q x r] matmul [1 x s x r x t] -> [p x s x q x t]

If we could go back in time as NumPy developers, we assuredly would change dot to work this way (now we cannot, because of backwards compatibility concerns)

So, given the choice between introducing a semantic difference between torch.dot and np.dot, copying np.dot's behavior, or not adding batch support to torch.dot, I'd go with the last option.

By the way, there was a similar discussion in #138.
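The two `matmul` shapes in the quoted NumPy discussion can be checked directly with `torch.matmul`, which follows the same broadcasting rules. The sketch below uses arbitrary sizes for p, q, r, s, t, and also shows how a batched dot product falls out of batch matmul by viewing each row as a 1 x S and an S x 1 matrix.

```python
import torch

p, q, r, s, t = 2, 3, 4, 5, 6

# batch matmul: [p x q x r] matmul [p x r x t] -> [p x q x t]
print(torch.matmul(torch.randn(p, q, r), torch.randn(p, r, t)).shape)
# torch.Size([2, 3, 6])

# outer-product-style matmul: [p x 1 x q x r] matmul [1 x s x r x t] -> [p x s x q x t]
print(torch.matmul(torch.randn(p, 1, q, r), torch.randn(1, s, r, t)).shape)
# torch.Size([2, 5, 3, 6])

# batched dot product via matmul: (B, 1, S) @ (B, S, 1) -> (B, 1, 1) -> (B,)
x, y = torch.randn(p, r), torch.randn(p, r)
out = (x.unsqueeze(-2) @ y.unsqueeze(-1)).squeeze(-1).squeeze(-1)
print(out.shape)  # torch.Size([2])
```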

@xidulu
Contributor
xidulu commented Apr 17, 2019

it looks like the mul + sum is always cheaper, and sometimes significantly so

In my test cases, the mul + sum method used more memory than the bmm method.

import time

import torch

batch_size = 128 * 1000 * 1000
dim = 16
use_cuda = True

x = torch.randn(batch_size, dim)
W = torch.randn(dim, dim)


start_time = time.time()

if use_cuda:
    x = x.cuda()
    W = W.cuda()
    torch.cuda.synchronize()  # make sure the host-to-device copy has finished

load_time = time.time() - start_time
print("Loading time: {} secs".format(load_time))

start_time = time.time()
lhs = torch.matmul(x, W)  # (batch_size, dim)

# This worked: row-wise x_i^T W x_i via bmm, (batch_size, 1, 1) flattened to (batch_size,)
result = torch.bmm(lhs.view(batch_size, 1, dim), x.view(batch_size, dim, 1)).reshape(-1)

# GPU ran out of memory: lhs * x allocates another (batch_size, dim) temporary before the sum
# result = (lhs * x).sum(-1)

if use_cuda:
    torch.cuda.synchronize()  # CUDA kernels are asynchronous; wait before stopping the clock

calc_time = time.time() - start_time
print("Calc time: {} secs".format(calc_time))

start_time = time.time()
result = result.cpu()
print(result.shape)
copy_back_time = time.time() - start_time
print("Copy back time: {} secs".format(copy_back_time))


total_time = calc_time + copy_back_time + load_time
print("Total time: {} secs".format(total_time))
print("Avg throughput: {}".format(batch_size / total_time))
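The extra memory comes from the elementwise product in the mul + sum version: it materializes a temporary with the same shape as the inputs before reducing, whereas `bmm` writes only one value per row. If you still want mul + sum, one workaround (a sketch, not something benchmarked in this thread) is to process the batch in chunks so that only one chunk's temporary is alive at a time. `chunked_bdot` and `chunk_size` are hypothetical names.

```python
import torch


def chunked_bdot(a: torch.Tensor, b: torch.Tensor, chunk_size: int = 1 << 20) -> torch.Tensor:
    """Row-wise dot product of a and b, both of shape (B, S) and the same dtype.

    Peak extra memory is one (chunk_size, S) temporary instead of a full (B, S) one.
    """
    out = torch.empty(a.shape[0], dtype=a.dtype, device=a.device)
    for start in range(0, a.shape[0], chunk_size):
        stop = start + chunk_size  # slicing past the end is clamped
        # The product for this chunk is released before the next iteration allocates its own
        out[start:stop] = (a[start:stop] * b[start:stop]).sum(-1)
    return out


a = torch.randn(10, 4)
b = torch.randn(10, 4)
print(torch.allclose(chunked_bdot(a, b, chunk_size=3), (a * b).sum(-1), atol=1e-6))  # True
```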

@VitalyFedyunin VitalyFedyunin added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 17, 2019
@ssnl
Collaborator
ssnl commented Jan 17, 2020

Some newer timings using the same setting as @samuelbroscheit, after warming up:

In [16]: %%timeit
    ...: out = torch.tensordot(x.T, y, dims=1)
    ...: torch.cuda.synchronize()
66.7 µs ± 655 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [17]: %%timeit
    ...: torch.einsum('bs,bs->b', x, y)
    ...: torch.cuda.synchronize()
70.4 µs ± 708 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [18]: %%timeit
    ...: out = torch.bmm(x.view(B, 1, S), y.view(B, S, 1)).reshape(-1)
    ...: torch.cuda.synchronize()
39.6 µs ± 1.11 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [19]: %%timeit
    ...: out = (x*y).sum(-1)
    ...: torch.cuda.synchronize()
45.8 µs ± 734 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

In [21]: %%timeit
    ...: out = torch.tensordot(x, y.T, dims=1)
    ...: torch.cuda.synchronize()
86.3 µs ± 724 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

PyTorch version: 1.4.0.dev20191205
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 16.04.6 LTS
GCC version: (Ubuntu 5.5.0-12ubuntu1~16.04) 5.5.0 20171010
CMake version: version 3.14.0

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: TITAN Xp
GPU 1: TITAN Xp
GPU 2: TITAN Xp
GPU 3: TITAN Xp

Nvidia driver version: 440.33.01
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.17.2
[pip] torch==1.4.0.dev20191205
[pip] torchfile==0.1.0
[pip] torchvision==0.5.0a0+7745e73
[conda] blas                      1.0                         mkl
[conda] mkl                       2019.4                      243
[conda] mkl-include               2019.4                      243
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.0.15           py37ha843d7b_0
[conda] mkl_random                1.1.0            py37hd6b4f25_0
[conda] pytorch                   1.4.0.dev20191205 py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch-nightly
[conda] torch                     1.3.0                    pypi_0    pypi
[conda] torchfile                 0.1.0                    pypi_0    pypi
[conda] torchvision               0.5.0.dev20191205      py37_cu101    pytorch-nightly
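When reading the tensordot timings, note that neither call returns the (B,) vector by itself. With x and y of shape (B, S), `torch.tensordot(x.T, y, dims=1)` contracts over the batch axis and yields an (S, S) matrix, while `torch.tensordot(x, y.T, dims=1)` yields the full (B, B) matrix whose diagonal is the batched dot product. A quick shape check, with small sizes chosen for illustration:

```python
import torch

B, S = 6, 4
x = torch.randn(B, S)
y = torch.randn(B, S)

a = torch.tensordot(x.T, y, dims=1)  # contracts the B axis: (S, B) x (B, S) -> (S, S)
b = torch.tensordot(x, y.T, dims=1)  # contracts the S axis: (B, S) x (S, B) -> (B, B)
print(a.shape, b.shape)  # torch.Size([4, 4]) torch.Size([6, 6])

# Only the diagonal of b is the batched dot product
assert torch.allclose(torch.diagonal(b), (x * y).sum(-1), atol=1e-6)
```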

@heitorschueroff
Contributor

While working on torch.einsum in #46398, I came across the need for this as well. I added a case in my PR that reduces to torch.dot, as it is significantly faster than torch.bmm, but with a batched dot this optimization could apply to many more cases.

ngraymon added a commit to ngraymon/DNN-SE that referenced this issue Dec 7, 2021
for some reason the 2nd derivative is still a bust even if I 'cheat'
and normalize the wavefunction

for more information about optimal dot products/ matmuls see
https://discuss.pytorch.org/t/dot-product-batch-wise/9746/12
https://pytorch.org/docs/stable/generated/torch.mul.html
pytorch/pytorch#18027

specifically the github issue shows that torch.mul().sum() is way more
efficient
lezcano added a commit that referenced this issue Jan 3, 2022
This PR adds the function `linalg.vecdot` specified by the [Array
API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot)

For the complex case, it chooses to implement \sum x_i y_i. See the
discussion in data-apis/array-api#356

Edit. When it comes to testing, this function is not quite a binopt, nor a reduction opt. As such, we're this close to being able to get the extra testing, but we don't quite make it. Now, it's such a simple op that I think we'll make it without this.

Resolves #18027.

cc mruberry rgommers pmeier asmeurer leofang AnirudhDagar asi1024 emcastillo kmaehashi

[ghstack-poisoned]
lezcano added further commits that referenced this issue, each with the same message as above, on Jan 5, Mar 10 (2), May 18 (2), May 19 (4), Jul 6 (4), Jul 11 (4) and Jul 12 (4), 2022.
pytorchmergebot pushed a commit that referenced this issue Jul 12, 2022
This PR adds the function `linalg.vecdot` specified by the [Array
API](https://data-apis.org/array-api/latest/API_specification/linear_algebra_functions.html#function-vecdot)

For the complex case, it chooses to implement \sum x_i y_i. See the
discussion in data-apis/array-api#356

Edit. When it comes to testing, this function is not quite a binopt, nor a reduction opt. As such, we're this close to being able to get the extra testing, but we don't quite make it. Now, it's such a simple op that I think we'll make it without this.

Resolves #18027.

cc @mruberry @rgommers @pmeier @asmeurer @leofang @AnirudhDagar @asi1024 @emcastillo @kmaehashi
Pull Request resolved: #70542
Approved by: https://github.com/IvanYashchuk, https://github.com/mruberry
facebook-github-bot pushed a commit that referenced this issue Jul 13, 2022
Summary: same commit message as above (Pull Request resolved: #70542).

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/e505796a2c0711e6e8ea39c12c19848dc49a60b5

Reviewed By: DanilBaibak

Differential Revision: D37813214

Pulled By: DanilBaibak

fbshipit-source-id: 782cb7cb31fda5dba4f3de70496f2470b41d8e34
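With that PR merged, the batched dot product requested in this issue is available as `torch.linalg.vecdot` in PyTorch releases that include it. A minimal sketch for real-valued inputs (complex inputs follow the conjugation convention described in the function's documentation, which this example does not exercise):

```python
import torch

x = torch.randn(1024, 128)
y = torch.randn(1024, 128)

# Dot product over the last dimension, batched over the leading ones: (B, S) -> (B,)
out = torch.linalg.vecdot(x, y, dim=-1)

# For real inputs this agrees with the mul + sum formulation from the thread
assert torch.allclose(out, (x * y).sum(-1), atol=1e-4)
print(out.shape)  # torch.Size([1024])
```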