Turn indexing with a scalar tensor into a copy instead of a view and avoid a D2H synchronization. #105641

Description

@Chillee

🚀 The feature, motivation and pitch

Today, this triggers a CUDA synchronization:

import torch

torch.set_default_device('cuda')

def f(x, y):
    return x[y]

inps = (torch.randn(5), torch.tensor(0))
torch.cuda.set_sync_debug_mode(2)  # error on any implicit CUDA synchronization
f(*inps)  # the 0-dim tensor index triggers a D2H sync, so this raises

The reason why is that when the index tensor is a 0-dim value, instead of launching a gather kernel, we move the tensor to the host and do a slice instead (https://github.com/pytorch/pytorch/pull/105518/files#diff-2574bfb0ffa78d685fb7bd2ebc0c64b1a5f87dd55ec74ae67b41b31adc566020L466).
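
Roughly, the two strategies look like this (a sketch for illustration only, not the actual indexing code; index_select just stands in for the gather path here):

import torch

x = torch.randn(5, device='cuda')
y = torch.tensor(3, device='cuda')

# Today's 0-dim path, morally: copy the index to the host (a D2H transfer
# that forces a sync), then do a plain slice, which returns a view.
view = x[y.item()]

# A gather-based path, morally: stay on the device and launch a kernel,
# but the result is a copy rather than a view.
copy = x.index_select(0, y.reshape(1)).squeeze(0)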

We could just fix this, but unfortunately, it changes the semantics: the operation would then create a copy instead of a view, which could break downstream in-place operations.

I think these are bad semantics, for 3 reasons:

  1. CUDA synchronizations are very bad in general. They're slow, prevent the use of many features (streams, CUDA graphs), don't play well with collectives, and should be strongly avoided. This, however, is a very implicit coercion we're doing. It's not obvious at all that indexing with a 3-dim/2-dim/1-dim tensor doesn't sync, but indexing with a 0-dim tensor does. In addition, it makes this operation much harder to trace and compile, and not composite compliant in general.

  2. Moreover, it's not consistent!!

Why should x[torch.tensor(0)] return a view but x[torch.tensor([0])] return a copy? Why should the first one do a synchronization while the second one doesn't?
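
For example (CPU tensors here just so the snippet runs anywhere; the view/copy asymmetry is the point, the sync is CUDA-only):

import torch

x = torch.ones(5)
v = x[torch.tensor(0)]     # 0-dim index: returns a view today
v += 1
print(x)                   # tensor([2., 1., 1., 1., 1.]) -- x was modified through the view

x = torch.ones(5)
c = x[torch.tensor([0])]   # 1-element index: advanced indexing, returns a copy
c += 1
print(x)                   # tensor([1., 1., 1., 1., 1.]) -- x is untouched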

To drive this point home further, we also diverge from NumPy semantics here:

import numpy as np

x = np.ones(5)
y = np.array(1)
z = x[y]   # NumPy returns a copy here, not a view
z += 1
print(x)
# [1. 1. 1. 1. 1.]  <- x is unchanged

  3. It's actually slower than just doing the indexing on GPU! Benchmarking x[torch.tensor(0)] vs. x[torch.tensor([0])], we see that the first takes 35 us per iteration while the second takes 8 us (a rough way to measure this is sketched below).
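
The exact numbers will vary by GPU, but here's a rough way to reproduce this kind of measurement (a sketch, not the harness used for the numbers above):

import time
import torch

torch.set_default_device('cuda')

def bench_us(fn, iters=1000):
    for _ in range(10):   # warmup
        fn()
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - t0) / iters * 1e6

x = torch.randn(5)
scalar_idx = torch.tensor(0)     # 0-dim: host-slice path, syncs on every call
vector_idx = torch.tensor([0])   # 1-element: gather kernel, no sync

print(bench_us(lambda: x[scalar_idx]), "us")
print(bench_us(lambda: x[vector_idx]), "us")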

PS: I've also done a brief survey of the uses of this pattern I could find (https://github.com/search?q=%2F%28%5Cw%2B%29%5C%5Btorch.tensor%5C%28%2F+language%3APython&type=code), and I couldn't find many uses of this code path at all.

cc: @ezyang @zou3519 @ngimel

cc @ezyang @gchanan @mruberry @rgommers

Metadata

    Labels

    module: advanced indexing (Related to x[i] = y, index functions)
    module: bc-breaking (Related to a BC-breaking change)
    module: numpy (Related to numpy support, and also numpy compatibility of our operators)
    topic: bc breaking (topic category)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
