
Unnecessary cuda synchronizations that we should remove in PyTorch #108968


Open · 2 of 6 tasks
Chillee opened this issue Sep 10, 2023 · 6 comments
Labels
module: cuda (Related to torch.cuda, and CUDA support in general)
module: performance (Issues related to performance, either of kernel code or framework glue)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Chillee (Collaborator) commented Sep 10, 2023

🚀 The feature, motivation and pitch

There are a number of unnecessary CUDA synchronizations in PyTorch ops, and I think we should endeavor to remove them whenever possible.
To check for syncs, you can use torch.cuda.set_sync_debug_mode("warn").
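For instance, a minimal way to surface these warnings might look like the following (nonzero() is just an illustrative op that needs its result count on the host):

import torch

# Warn whenever an op triggers an implicit CUDA synchronization
# (passing "error" would raise instead of warning).
torch.cuda.set_sync_debug_mode("warn")

x = torch.randn(4, device='cuda')
idx = x.nonzero()  # copies the number of nonzero elements back to the CPU, so this warns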

I'm creating this issue to track ones that I've seen/found.

  • torch.multinomial with a CUDA tensor forces a synchronization:
A = torch.rand(10, device='cuda')
torch.multinomial(A, num_samples=1)
  • repeat_interleave with a tensor number of repeats requires a synchronization: repeats cannot be a non-CUDA tensor when the input is on CUDA, so computing the output shape forces a device-to-host sync. For this I think we should add a list-of-ints overload or allow passing a CPU tensor for repeats. (See the note after the example below.)
A = torch.randn(3, device='cuda')
num_repeats = torch.tensor([2, 3, 5])
out = torch.repeat_interleave(A, num_repeats.cuda(), dim=0)
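As a side note, when the total output length is already known, the existing output_size keyword of repeat_interleave lets the kernel skip the device-to-host copy otherwise needed to compute the output shape; a minimal sketch:

A = torch.randn(3, device='cuda')
num_repeats = torch.tensor([2, 3, 5], device='cuda')
# output_size is the sum of the repeats (10 here); providing it avoids the
# stream synchronization needed to calculate the output shape.
out = torch.repeat_interleave(A, num_repeats, dim=0, output_size=10)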

Alternatives

No response

Additional context

No response

cc @ptrblck

vadimkantorov (Contributor) commented Sep 10, 2023

@Chillee maybe that's also part of why repeat_interleave is slow: #31980; also a bit related: #73175

lezcano (Collaborator) commented Sep 10, 2023

On point 2, see data-apis/array-api#654. The array API will have repeats be a tuple.

vadimkantorov (Contributor) commented Sep 10, 2023

I hope that a tensor/array can also be accepted where a tuple is expected. This might be important for PyTorch, where we can get some minor improvements if the counts are already a tensor (obtained from some computation with arrays) and are already on the target device.

Chillee (Collaborator, Author) commented Sep 10, 2023

@vadimkantorov repeat_interleave currently takes in a tensor but not tuples. In general, it's not always a good idea to take in tensors where we currently take tuples, since they might require device-to-host synchronizations anyways (for example, view).
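A rough illustration of that point (hypothetical usage, not an existing overload): even if shape-like arguments could be CUDA tensors, the host would still need their concrete values to build the output's sizes and strides, so a device-to-host copy happens somewhere:

x = torch.randn(10, device='cuda')
shape = torch.tensor([2, 5], device='cuda')
# view() needs concrete integers on the host to compute sizes/strides, so
# reading them out of a CUDA tensor (here via tolist()) is itself a sync.
y = x.view(*shape.tolist())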

drisspg added the module: performance, module: cuda, and triaged labels Sep 11, 2023
kgryte commented Sep 21, 2023

The array API will have repeats be a tuple.

@lezcano Currently, most array libraries support a one-dimensional array for repeats. For that proposal, based on recent feedback, I think we're leaning toward supporting both sequences and arrays.

DeNeutoy (Contributor) commented Oct 20, 2023

+1 to this issue - this is causing some non-trivial performance issues for us:

[Screenshot attached in the original comment]

For context, this repeat_interleave is very large (e.g. the repeats tensor may be of size ~128, with values up to 4000).

I don't think this is causing slowdowns exactly, but we do train multi-GPU models, where random syncing inside CUDA ops is presumably more of a problem.
