Use float data type for Half sum in fallback implementation of batchnorm backward on CPU #147353
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147353
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 746a203 with merge base e8b20f6.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Please write the PR description.
test/test_nn.py (outdated)
for bwd_format in [torch.contiguous_format, torch.channels_last]:
    helper(self, nn.BatchNorm2d, (16, 3, 224, 224), torch.float, fwd_format, bwd_format)

for fwd_format in [torch.contiguous_format, torch.channels_last_3d]:
nit: you could also write this as

formats = [torch.contiguous_format, torch.channels_last_3d]
for (fwd_format, bwd_format) in itertools.product(formats, formats):
    helper(...)

which should be easier to read. See the following to illustrate:

>>> choices = [0, 1]
>>> list(itertools.product(choices, choices))
[(0, 0), (0, 1), (1, 0), (1, 1)]

Or use parametrize, as in the other tests, to get one test for each combo; this should make it easier to catch any future failures.
Thanks for your comments. Revised.
auto sum = grad_out_.scalar_type() == kHalf
    ? at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims)
    : at::sum(grad_out_, /*dim=*/reduce_dims);
using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, float, scalar_t>;
Also a nit, but would it be useful to use accscalar_t from above?
Thanks for your comments. Do you mean using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, accscalar_t, scalar_t>;? We may not be able to use accscalar_t instead of sum_t, as I only convert the input of the Half-type sum to float for the calculation.
I meant accscalar_t instead of sum_t, since for Half it should map to float: https://github.com/pytorch/pytorch/blob/af3164039158f38ebe7ff17300c0307ecb0abcd6/aten/src/ATen/AccumulateType.h
This would also affect other types: accscalar_t = float for scalar_t = BFloat16, and accscalar_t = double for scalar_t = float. This PR is currently only intended to fix the Half overflow case. Do you mean that the other types should also use higher precision for the sum?
@CaoE - oh, I'd still guard on Half, but replace float with accscalar_t to create an explicit connection and prevent issues under any future changes. In any case, it might be worthwhile to check the other types as well if there are numerical stability issues for Half, but that can be addressed in a follow-up PR.
auto sum_a = sum.accessor<scalar_t, 1>();
// Using float data type for Half sum to avoid overflow
// since the representation range of Half is small.
auto sum = grad_out_.scalar_type() == kHalf
also very nit: I think an if/else here might be more readable than a ternary expression
Can I keep the ternary expression? If I use if/else, the code becomes longer, like:

Tensor sum;
if (grad_out_.scalar_type() == kHalf) {
  sum = at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims);
} else {
  sum = at::sum(grad_out_, /*dim=*/reduce_dims);
}
Sure, either one should work. I personally don't mind more verbose / longer code if it's more readable and easier to maintain.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Use float data type for Half sum in fallback implementation of batchnorm backward on CPU (#147353)
Fixes #147303. Use float data type for Half sum in the fallback implementation of batchnorm backward on CPU, as the representation range of Half is small.
Pull Request resolved: #147353
Approved by: https://github.com/leslie-fang-intel, https://github.com/cpuhrsch
Fixes #147303.
Use float data type for Half sum in the fallback implementation of batchnorm backward on CPU, as the representation range of Half is small.
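As an illustration of why the Half sum overflows (a minimal sketch, not code from the PR or the linked issue; the tensor size and the grad name are made up):

```python
import torch

# The largest finite Half value is 65504, so any sum whose true value exceeds
# that range saturates to inf when the result is produced as a Half tensor.
print(torch.finfo(torch.half).max)        # 65504.0

grad = torch.full((100000,), 1.0, dtype=torch.half)   # stand-in for grad_out_
print(grad.sum())                         # tensor(inf, dtype=torch.float16)

# Summing in float, as the fix does for kHalf inputs, keeps the result finite.
print(grad.to(torch.float).sum())         # tensor(100000.)
```

In the batchnorm backward fallback, the sum runs over the reduce dimensions, which for the 16x3x224x224 test input is roughly 800k elements per channel, so even moderate gradient magnitudes can push a Half sum past this limit.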