Add _foreach_add(TensorList tl1, TensorList tl2) and _foreach_add_(TensorList tl1, TensorList tl2) APIs #42533
Conversation
💊 CI failures summary and remediations
As of commit 729d53f (more details on the Dr. CI page):
🕵️ 1 new failure recognized by patterns
The following CI failures do not appear to be due to upstream breakages.
Which optimizer needs this?
Out-of-place list add, e.g. SGD (see the sketch below).
Broadcasting?
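For context, a hedged sketch (not from this PR) of the kind of list update an SGD-style optimizer could do with the new in-place API. The learning-rate scaling via a list comprehension is illustrative only, since scaling ops for lists are planned for later PRs:

```python
import torch

# Hypothetical parameters and gradients for a model with several small tensors.
params = [torch.randn(10, 10) for _ in range(4)]
grads = [torch.randn(10, 10) for _ in range(4)]
lr = 0.1

# One foreach call replaces a Python loop of per-tensor add_ calls.
# (The pre-scaling list comprehension is illustrative only.)
torch._foreach_add_(params, [-lr * g for g in grads])
```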
torch._foreach_add_(tensors1, tensors2)
self.assertEqual(res, tensors1)
self.assertEqual(res[0], torch.ones(10, 10, device=device, dtype=dtype))
It would be nice to test some edge cases (e.g., number of tensors in the list >= 100).
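A minimal sketch of what such an edge-case test could look like (hypothetical test and class names; assumes the API added in this PR):

```python
import unittest
import torch

class TestForeachEdgeCases(unittest.TestCase):
    def test_add_list_many_tensors(self):
        n = 150  # more than the ~100 tensors suggested above
        tensors1 = [torch.zeros(10, 10) for _ in range(n)]
        tensors2 = [torch.ones(10, 10) for _ in range(n)]

        res = torch._foreach_add(tensors1, tensors2)
        self.assertEqual(len(res), n)
        for r in res:
            self.assertTrue(torch.equal(r, torch.ones(10, 10)))

if __name__ == "__main__":
    unittest.main()
```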
aten/src/ATen/native/ForeachUtils.h
Outdated
auto expected_dtype = tensors1[0].dtype();

for (int i = 0; i < tensors1.size(); i++) {
The PR body implies that "All tensors in the list must have the same [device]". Should that be checked here?
I have this check in check_fast_route: we don't want to fail the op if tensors are on different devices, but just go via the regular route.
std::vector<Tensor> foreach_tensor_add_list_kernel_cuda(TensorList tensors1, TensorList tensors2) {
verify_list(tensors1, tensors2);

if (!check_fast_route(tensors1, tensors2)) {
nit: As a function name, "check_fast_route" doesn't sound like it would return a boolean (it sounds like it would return nothing). Something like "can_use_fast_route" sounds like it would return a boolean.
Agree, renamed.
return at::native::foreach_tensor_add_list_kernel_slow(tensors1, tensors2);
}

std::vector<std::vector<at::Tensor>> tensor_lists;
(no action required) We know exactly how many lists there need to be, so it would be nice if we could avoid the dynamic allocation here. This would require us to revisit the first argument to multi_tensor_apply.
Added to TODO as a follow-up.
}
}
else {
// Non-divergent exit condition for __syncthreads, not necessary here
Where is the __syncthreads this is referring to? What exactly is "not necessary here"?
Irrelevant in our case; removed.
n -= chunk_idx * chunk_size;

T r_x[kILP];
What is kILP? It looks like it is equal to 4, but I'm not sure what that means
Instruction-level parallelism.
Please note that this code is 95% taken from Apex, so most of the comments/constants come from there.
}
#pragma unroll
for(int ii = 0; ii < kILP; ii++) {
r_out[ii] = static_cast<T>(r_x[ii]) + static_cast<T>(r_y[ii]);
It seems like the important part of the Functor is `r_out[ii] = static_cast<T>(r_x[ii]) + static_cast<T>(r_y[ii]);`, and the other logic is pretty standard between Functors (from comparing AddListFunctor and AddListFunctor_). Do we ever want functors that don't have similar logic? If not, it might be worth trying to abstract away the rest of the code via templates or macros so that someone implementing a new Functor doesn't have to copy and paste code that is considered boilerplate.
(To clarify, no action necessary, but I am curious whether we can clean this up in the future.)
Yes, I thought about this. I had a hard time figuring out what all the functors will look like, so I decided to add them all as-is and refactor later once I have all the requirements and corner cases. But I'm open for discussion here.
I'm also not sure what the functors will look like in the end state, so I agree it makes sense to refactor later.
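As a loose Python analogy of the factoring discussed in this thread (the real change would use C++ templates or macros, and `foreach_binary_op` is a hypothetical name): only the per-element op varies, so it could be passed in while the surrounding list iteration lives in one place.

```python
import torch

def foreach_binary_op(op, tensors1, tensors2):
    # Slow-path-style reference: apply `op` pairwise over the two lists.
    return [op(t1, t2) for t1, t2 in zip(tensors1, tensors2)]

a = [torch.zeros(2, 2) for _ in range(3)]
b = [torch.ones(2, 2) for _ in range(3)]
res = foreach_binary_op(torch.add, a, b)
assert all(torch.equal(r, torch.ones(2, 2)) for r in res)
```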
I need to go read through how multi_tensor_apply and functors work, but the code looks like it is following in the footsteps of the last PRs, so I don't have any major comments. I added some minor comments about the testing and some questions inline.
std::vector<Tensor> foreach_tensor_add_list_kernel_slow(TensorList tensors1, TensorList tensors2) {
verify_list(tensors1, tensors2);

std::vector<Tensor> result;
nit: result.reserve(tensors1.size())
Done
aten/src/ATen/native/ForeachUtils.h
Outdated
TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
}
}

// To go via 'fast' path, several conditions must be satisfied
void verify_list(TensorList tensors1, TensorList tensors2) {
A comment here noting that these are the restrictions for the foreach TensorList APIs would be nice.
Done
test/test_foreach.py
Outdated
# different devices
tensor1 = torch.zeros(10, 10, device="cuda:0")
tensor2 = torch.ones(10, 10, device="cuda:1")
with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!"):
Feel free to just check part of the string, e.g. "Expected all tensors to be on the same device". The goal of checking assertRaisesRegex is to make sure the error message isn't something completely unreadable like "TORCH_INTERNAL_ASSERT(false)".
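A sketch of the relaxed check as a hypothetical standalone test (class and test names are made up; needs two CUDA devices):

```python
import unittest
import torch

class TestForeachErrors(unittest.TestCase):
    @unittest.skipUnless(torch.cuda.device_count() >= 2, "needs two CUDA devices")
    def test_add_inplace_different_devices(self):
        tensors1 = [torch.zeros(10, 10, device="cuda:0")]
        tensors2 = [torch.ones(10, 10, device="cuda:1")]
        # Matching only a stable prefix keeps the test robust to wording changes
        # while still rejecting opaque internal-assert failures.
        with self.assertRaisesRegex(RuntimeError, "Expected all tensors to be on the same device"):
            torch._foreach_add_(tensors1, tensors2)

if __name__ == "__main__":
    unittest.main()
```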
test/test_foreach.py
Outdated
# Corresponding tensors with different sizes
tensors1 = [torch.zeros(10, 10, device=device) for _ in range(10)]
tensors2 = [torch.ones(11, 11, device=device) for _ in range(10)]
with self.assertRaisesRegex(RuntimeError, "Corresponding tensors in lists must have the same size, got \[10, 10\] and \[11, 11\]"):
'\[' isn't a valid escape sequence for regular strings; you have to prepend an r to the string: r"Corresponding tensors in lists..."
Done
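For reference, a quick standalone illustration of the raw-string point (not part of the test suite):

```python
import re

msg = "Corresponding tensors in lists must have the same size, got [10, 10] and [11, 11]"
# In a raw string the backslashes reach the regex engine unchanged, so the
# square brackets are matched literally rather than starting a character class.
pattern = r"Corresponding tensors in lists must have the same size, got \[10, 10\] and \[11, 11\]"
assert re.search(pattern, msg) is not None
```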
aten/src/ATen/native/ForeachUtils.h
Outdated
TORCH_CHECK(t.dtype() == expected_dtype, "All tensors in the tensor list must have the same dtype.");
}
}

// To go via 'fast' path, several conditions must be satisfied
void verify_list(TensorList tensors1, TensorList tensors2) {
nit: this is in the at::native namespace. To avoid name collision, it might be nice to name this something more detailed, like "check_foreach_api_restrictions"
Done
Stack from ghstack:

[First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554)

**Motivation**

[GitHub issue](#38655)

Current PyTorch optimizer implementations are not efficient when we work with a lot of small feature tensors: starting a lot of kernels slows down the whole process, so we need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIA's Apex](https://github.com/NVIDIA/apex). To track progress, we will pick PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model's performance against the original model version, Apex's version with the original Adam optimizer, and its FusedAdam optimizer.

**Current API restrictions**

- Lists can't be empty (will be fixed in upcoming PRs).
- All tensors in the list must have the same dtype, device and size.

**Broadcasting**

At this point we don't support broadcasting.

**What the 'fast' and 'slow' routes are**

In particular cases, we can't process an op with a fast list CUDA kernel, but we can still do it with a regular for-loop where the op is applied to each tensor individually through the dispatch mechanism. A few checks decide whether the op will be performed via the 'fast' or the 'slow' path. To go the fast route:

- All tensors must have a strided layout.
- All tensors must be dense and not have overlapping memory.
- The resulting tensor type must be the same.

**In this PR**

- Adding a `_foreach_add(TensorList tl1, TensorList tl2)` API
- Adding a `_foreach_add_(TensorList tl1, TensorList tl2)` API

**Tests**

Tested via unit tests.

**TODO**

1. Properly handle empty lists.

**Plan for the next PRs**

1. APIs
   - Binary ops for a list with a Scalar
   - Binary ops for a list with a list
   - Unary ops for a list
   - Pointwise ops
2. Complete tasks from the TODO.
3. Rewrite PyTorch optimizers to use for-each operators for performance gains.

Differential Revision: [D23331894](https://our.internmc.facebook.com/intern/diff/D23331894)
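For illustration, a minimal usage sketch of the two APIs added here (shapes and values are arbitrary):

```python
import torch

tensors1 = [torch.zeros(10, 10) for _ in range(5)]
tensors2 = [torch.ones(10, 10) for _ in range(5)]

# Out-of-place: returns a new list with one result tensor per input pair.
res = torch._foreach_add(tensors1, tensors2)

# In-place: tensors1[i] += tensors2[i] for every i.
torch._foreach_add_(tensors1, tensors2)

assert all(torch.equal(r, t) for r, t in zip(res, tensors1))
```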