Add _foreach_add_(TensorList tensors, Scalar scalar) API #42531
Conversation
Can you briefly describe why changes to autograd and jit are necessary?
```
@@ -5450,6 +5450,13 @@
    CPU: foreach_add_scalar_kernel_fallback
    CUDA: foreach_tensor_add_scalar_kernel_cuda

- func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> Tensor[](a!)
```
We don't have in-place functions, only methods. Is it feasible to make this an out-variant function?
Actually, this isn't quite true: a bunch of stuff in nn::functional is defined this way (although those functions aren't typically directly accessible).
I thought about this as well, but if I change it to a method and try building, I get:

> RuntimeError: Function 'foreach_add' starts with a single underscore and is configured to have a method on Tensor. Functions that start with a single underscore should only be functions in the at:: namespace and not methods on Tensor!

which made me follow other examples in native_functions.yaml.
sounds good then!
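To illustrate the constraint in that error message, here is a minimal sketch, assuming the op ends up bound under the private `torch` namespace in Python: an underscore-prefixed op is callable as a free function, not as a Tensor method.

```python
import torch

ts = [torch.zeros(2) for _ in range(3)]
torch._foreach_add_(ts, 1.0)   # OK: free function in the torch/at:: namespace
# ts[0]._foreach_add_(1.0)     # would fail: underscore-prefixed ops get no Tensor method
```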
Wouldn't `Tensor[](a!)` mean that the list is being mutated, and `Tensor(a!)[]` be a list of mutable tensors?

@suo is that right? Is this safe to add?
Coming in here late: I noticed this problem too while working on the new codegen at #42629. I also agree that `Tensor(a!)[]` (a list of mutable tensors) is the correct type, and that what was landed (`Tensor[](a!)`) is not correct.
I added a small fix for the problem in the new codegen. Besides

```diff
- - func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> ()
-   device_guard: False
+ - func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
+   device_guard: False
```

I also had to add a special case to make a void return OK for trailing-underscore (in-place) functions:

```python
if self.name.name.inplace:
    # TODO: fixme
    if str(self.name) not in [
            '_amp_non_finite_check_and_unscale_',
            '_foreach_add_.Scalar']:
        assert len(self.returns) == 1
```
… API" **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. [First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554). **In this PR** - Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API - Resolving some additional comments from previous [PR](#41554). **Tests** Tested via unit tests **TODO** 1. Properly handle empty lists **Plan for the next PRs** 1. APIs - Binary Ops for list with Scalar - Binary Ops for list with list - Unary Ops for list 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. [ghstack-poisoned]
… API" **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. [First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554). **In this PR** - Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API - Resolving some additional comments from previous [PR](#41554). **Tests** Tested via unit tests **TODO** 1. Properly handle empty lists **Plan for the next PRs** 1. APIs - Binary Ops for list with Scalar - Binary Ops for list with list - Unary Ops for list 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. [ghstack-poisoned]
… API" **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. [First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554). **In this PR** - Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API - Resolving some additional comments from previous [PR](#41554). **Tests** Tested via unit tests **TODO** 1. Properly handle empty lists **Plan for the next PRs** 1. APIs - Binary Ops for list with Scalar - Binary Ops for list with list - Unary Ops for list 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. [ghstack-poisoned]
… API" **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. [First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554). **In this PR** - Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API - Resolving some additional comments from previous [PR](#41554). **Tests** Tested via unit tests **TODO** 1. Properly handle empty lists **Plan for the next PRs** 1. APIs - Binary Ops for list with Scalar - Binary Ops for list with list - Unary Ops for list 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. [ghstack-poisoned]
… API" [First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554). **Motivation** [GitHub issue](#38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer. **Current API restrictions** - List can't be empty (will fixed in upcoming PRs). - All tensors in the list must have the same dtype, device and size. **Broadcasting** At this point we don't support broadcasting. **What is 'Fast' and 'Slow' route** In particular cases, we cant process an op with a fast list CUDA kernel. Still, we can do with a regular for-loop where the op will be applied to each tensor individually through the dispatch mechanisms. There are a few checks that decide whether the op will be performed via a 'fast' or 'slow' path. To go the fast route, - All tensors must have strided layout - All tensors must be dense and not have overlapping memory - The resulting tensor type must be the same. --------------- **In this PR** - Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API - Resolving some additional comments from previous [PR](#41554). **Tests** Tested via unit tests **TODO** 1. Properly handle empty lists **Plan for the next PRs** 1. APIs - Binary Ops for list with Scalar - Binary Ops for list with list - Unary Ops for list - Pointwise Ops 2. Complete tasks from TODO 3. Rewrite PyTorch optimizers to use for-each operators for performance gains. [ghstack-poisoned]
@gchanan, @ngimel, @cpuhrsch
Codecov Report

```
@@            Coverage Diff             @@
##   gh/izdeby/35/base   #42531   +/- ##
====================================================
  Coverage           ?   69.34%
====================================================
  Files              ?      378
  Lines              ?    46674
  Branches           ?        0
====================================================
  Hits               ?    32364
  Misses             ?    14310
  Partials           ?        0
```

Continue to review the full report at Codecov.
Stack from ghstack:

[First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](#41554)

**Motivation**

[GitHub issue](#38655)

Current PyTorch optimizer implementations are inefficient when we work with many small feature tensors: launching a separate kernel for each tensor slows the whole process down, so we need to reduce the number of kernel launches. As a reference, we are looking at [NVIDIA's Apex](https://github.com/NVIDIA/apex). To track progress, we will take PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark its performance against the original model, Apex's version with the original Adam optimizer, and Apex's FusedAdam optimizer.

**Current API restrictions**
- The list can't be empty (will be fixed in upcoming PRs).
- All tensors in the list must have the same dtype, device and size.

**Broadcasting**

At this point we don't support broadcasting.

**What are the 'fast' and 'slow' routes**

In some cases we can't process an op with the fast list CUDA kernel, but we can still fall back to a regular for-loop in which the op is applied to each tensor individually through the regular dispatch mechanism. A few checks decide whether an op takes the 'fast' or the 'slow' path (see the sketch after this list). To go the fast route:
- All tensors must have strided layout.
- All tensors must be dense and must not have overlapping memory.
- The resulting tensor type must be the same.
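For illustration, here is a minimal Python sketch of that routing logic. The helper names (`can_use_fast_route`, `slow_path_add_`) are hypothetical and the checks are simplified stand-ins; the real checks and kernels are implemented in C++/CUDA in this PR.

```python
import torch

def can_use_fast_route(tensors):
    # Simplified stand-ins for the checks listed above.
    return (
        all(t.layout == torch.strided for t in tensors)  # strided layout
        and all(t.is_contiguous() for t in tensors)      # dense, non-overlapping (approximation)
        and len({t.dtype for t in tensors}) == 1         # same resulting type (approximation)
    )

def slow_path_add_(tensors, scalar):
    # Slow route: apply the op to each tensor individually through the
    # regular dispatch mechanism, i.e. one kernel launch per tensor.
    for t in tensors:
        t.add_(scalar)
```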
**In this PR**
- Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API
- Resolving some additional comments from the previous [PR](#41554)

**Tests**

Tested via unit tests.

**TODO**
1. Properly handle empty lists.

**Plan for the next PRs**
1. APIs
   - Binary ops for a list with a Scalar
   - Binary ops for a list with a list
   - Unary ops for a list
   - Pointwise ops
2. Complete tasks from the TODO list.
3. Rewrite PyTorch optimizers to use for-each operators for performance gains.

Differential Revision: [D23331892](https://our.internmc.facebook.com/intern/diff/D23331892)
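For completeness, a usage sketch: assuming the op is exposed in Python as `torch._foreach_add_` (the binding name here is an assumption based on the C++ function name), a single call updates every tensor in the list in place.

```python
import torch

params = [torch.randn(3, 3) for _ in range(10)]
expected = [p + 1.0 for p in params]

# One API call instead of ten separate p.add_(1.0) calls; on CUDA the
# fast route services the whole list with far fewer kernel launches.
torch._foreach_add_(params, 1.0)

assert all(torch.equal(p, e) for p, e in zip(params, expected))
```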