🚀 The feature, motivation and pitch
Today, this triggers a CUDA synchronization:
```python
import torch

torch.set_default_device('cuda')

def f(x, y):
    return x[y]

inps = (torch.randn(5), torch.tensor(0))
torch.cuda.set_sync_debug_mode(2)
f(*inps)
```
The reason is that when the index tensor is 0-dim, instead of launching a gather kernel, we move the tensor to the host and do a slice instead (https://github.com/pytorch/pytorch/pull/105518/files#diff-2574bfb0ffa78d685fb7bd2ebc0c64b1a5f87dd55ec74ae67b41b31adc566020L466).
We could just fix this, but unfortunately, that changes the semantics: the operation would then create a copy instead of a view, which could cause issues for downstream in-place operations.
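To illustrate the difference, here's a minimal sketch of the two behaviors described above (assuming a CUDA build; the variable names are just for illustration):

```python
import torch

x = torch.randn(5, device='cuda')

# Today: a 0-dim index is copied to the host (a sync) and turned into a slice,
# so the result is a view and in-place updates are visible through x.
v = x[torch.tensor(0, device='cuda')]
v.add_(1)  # changes x[0]

# A 1-dim index dispatches to a gather kernel and returns a copy,
# so the same in-place update leaves x untouched.
c = x[torch.tensor([0], device='cuda')]
c.add_(1)  # does not change x[0]
```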
I think these are bad semantics, for 3 reasons:
- CUDA synchronizations are very bad in general. They're slow, they prevent the use of many features (streams, CUDA graphs, they don't play well with collectives, etc.), and they should be strongly avoided. This, however, is a very implicit coercion we're doing. It's not at all obvious that indexing with a 3-dim/2-dim/1-dim tensor doesn't sync, but indexing with a 0-dim tensor does. In addition, it makes this operation much harder to trace and compile, and not composite compliant in general.
- Moreover, it's not consistent! Why should `x[torch.tensor(0)]` return a view but `x[torch.tensor([0])]` return a copy? Why should the first one do a synchronization while the second one doesn't?
To drive this point home further, we also diverge from NumPy semantics here:
```python
import numpy as np

x = np.ones(5)
y = np.array(1)
z = x[y]
z += 1
print(x)
# Output: array([1., 1., 1., 1., 1.])
```
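For comparison, a sketch of the equivalent PyTorch snippet under today's view semantics, where the in-place update is visible through `x` (unlike NumPy):

```python
import torch

x = torch.ones(5)
y = torch.tensor(1)
z = x[y]   # today: a view into x (and a sync when x is on CUDA)
z += 1
print(x)
# Output: tensor([1., 2., 1., 1., 1.])  (x was mutated, unlike NumPy)
```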
- It's actually slower than just doing the indexing operation on the GPU! Benchmarking `x[torch.tensor(0)]` vs. `x[torch.tensor([0])]`, we see that the first takes 35 us per iteration while the second takes 8 us.
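For reference, a rough sketch of how such a measurement could be reproduced with `torch.utils.benchmark` (this is not the original benchmark script, and the exact numbers will vary by GPU):

```python
import torch
from torch.utils.benchmark import Timer

x = torch.randn(5, device='cuda')
indices = {
    '0-dim index': torch.tensor(0, device='cuda'),    # host sync + slice path
    '1-dim index': torch.tensor([0], device='cuda'),  # gather kernel path
}

for label, idx in indices.items():
    t = Timer(stmt='x[idx]', globals={'x': x, 'idx': idx})
    print(label, t.blocked_autorange())
```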
PS: I've also done a brief survey of uses of this pattern that I could find (https://github.com/search?q=%2F%28%5Cw%2B%29%5C%5Btorch.tensor%5C%28%2F+language%3APython&type=code), and I couldn't find many instances of this code path at all.