Use float data type for Half sum in fallback implementation of batchnorm backward on CPU by CaoE · Pull Request #147353 · pytorch/pytorch · GitHub

Use float data type for Half sum in fallback implementation of batchnorm backward on CPU #147353


Closed
wants to merge 1 commit

Conversation

CaoE
Collaborator
@CaoE CaoE commented Feb 18, 2025

Fixes #147303.
Use the float data type for the Half sum in the fallback implementation of batchnorm backward on CPU, since the representable range of Half is small.
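
For background, here is a minimal Python-level sketch of the kind of overflow this guards against (the tensor shape is illustrative, not taken from the linked issue):

import torch

# float16 ("Half") has a small representable range: its largest finite value is 65504.
print(torch.finfo(torch.half).max)   # 65504.0

# Any running total kept in Half overflows to inf once it passes that limit:
a = torch.tensor(60000.0, dtype=torch.half)
print(a + a)                         # tensor(inf, dtype=torch.float16)

# The pattern applied in this PR, expressed at the Python level: cast the Half
# gradient to float before reducing, so the per-channel sum is accumulated in float32.
grad_out = torch.ones(16, 3, 224, 224, dtype=torch.half)   # hypothetical shape
per_channel_sum = grad_out.float().sum(dim=(0, 2, 3))      # 16*224*224 = 802816, exact in float32
print(per_channel_sum)               # tensor([802816., 802816., 802816.])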

pytorch-bot bot commented Feb 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147353

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 746a203 with merge base e8b20f6:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: nn release notes category label Feb 18, 2025
@CaoE CaoE added ciflow/trunk Trigger trunk jobs on your pull request ciflow/inductor labels Feb 18, 2025
Collaborator
@leslie-fang-intel leslie-fang-intel left a comment


Please write the PR description.

@CaoE CaoE requested a review from mingfeima February 18, 2025 06:42
@CaoE CaoE marked this pull request as ready for review February 18, 2025 06:43
@CaoE CaoE requested a review from cpuhrsch February 18, 2025 06:43
test/test_nn.py Outdated
for bwd_format in [torch.contiguous_format, torch.channels_last]:
helper(self, nn.BatchNorm2d, (16, 3, 224, 224), torch.float, fwd_format, bwd_format)

for fwd_format in [torch.contiguous_format, torch.channels_last_3d]:
Contributor


nit: you could also write this as

formats = [torch.contiguous_format, torch.channels_last_3d]
for (fwd_format, bwd_format) in itertools.product(formats, formats):
    helper(...)

which should be easier to read.

See the following to illustrate

>>> choices = [0, 1]
>>> list(itertools.product(choices, choices))
[(0, 0), (0, 1), (1, 0), (1, 1)]

or use parametrize as in the other tests to be able to get one test for each combo.

This should make it easier to catch any future failures.
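
For reference, a rough sketch of the parametrize variant mentioned above (the test class, test name, and inline body are illustrative stand-ins for the existing helper in test_nn.py):

import itertools
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import (
    TestCase, instantiate_parametrized_tests, parametrize, run_tests)

memory_formats = [torch.contiguous_format, torch.channels_last]

class TestBatchNormFormats(TestCase):
    # One generated test per (fwd_format, bwd_format) pair, so a failure report
    # names the offending combination directly instead of the whole loop.
    @parametrize("fwd_format,bwd_format",
                 list(itertools.product(memory_formats, memory_formats)))
    def test_batchnorm_backward_formats(self, fwd_format, bwd_format):
        # Stand-in body; the real test would call the existing helper with
        # nn.BatchNorm2d, (16, 3, 224, 224), torch.float and the two formats.
        mod = nn.BatchNorm2d(3)
        x = torch.randn(4, 3, 8, 8, requires_grad=True)
        out = mod(x.to(memory_format=fwd_format))
        out.backward(torch.ones_like(out).contiguous(memory_format=bwd_format))
        self.assertTrue(torch.isfinite(x.grad).all())

instantiate_parametrized_tests(TestBatchNormFormats)

if __name__ == "__main__":
    run_tests()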

Collaborator Author


Thanks for your comments. Revised.

auto sum = grad_out_.scalar_type() == kHalf
? at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims)
: at::sum(grad_out_, /*dim=*/reduce_dims);
using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, float, scalar_t>;
Contributor


Also a nit, but would it be useful to use accscalar_t from above?

Collaborator Author
@CaoE CaoE Feb 19, 2025


Thanks for your comments. Do you mean using sum_t = std::conditional_t<std::is_same_v<scalar_t, at::Half>, accscalar_t, scalar_t>;? We may not be able to use accscalar_t in place of sum_t, as I only convert the Half input of the sum to float for the computation.

Contributor


I meant accscalar_t instead of sum_t since for half it should map to float: https://github.com/pytorch/pytorch/blob/af3164039158f38ebe7ff17300c0307ecb0abcd6/aten/src/ATen/AccumulateType.h

Collaborator Author


This would also change the other types, e.g. accscalar_t = float for scalar_t = BFloat16 and accscalar_t = double for scalar_t = float. This PR is currently only intended to fix the Half overflow case. Do you mean the other types should also use higher precision for the sum?

Contributor


@CaoE - oh, I'd still guard on Half, but replace float with accscalar_t to create an explicit connection / prevent issues under any changes. In any case, it might be worthwhile to check the other types as well if there are numerical stability issues for Half. But this can be addressed in a follow up PR.
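
For reference, the floating-point ranges involved (from torch.finfo), which is the motivation for limiting the change to Half in this PR:

import torch

# Half's small maximum is what lets the per-channel gradient sum overflow;
# BFloat16 shares float32's exponent range, so its sums do not hit this limit.
print(torch.finfo(torch.half).max)      # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
print(torch.finfo(torch.float32).max)   # ~3.40e38
print(torch.finfo(torch.float64).max)   # ~1.80e308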

auto sum_a = sum.accessor<scalar_t, 1>();
// Using float data type for Half sum to avoid overflow
// since the representation range of Half is small.
auto sum = grad_out_.scalar_type() == kHalf
Contributor
@cpuhrsch cpuhrsch Feb 19, 2025


also very nit: I think an if/else here might be more readable than a ternary expression

Collaborator Author


Can I keep the ternary expression? If I use if/else, the code becomes longer:

Tensor sum;
if (grad_out_.scalar_type() == kHalf) {
  sum = at::sum(grad_out_.to(ScalarType::Float), /*dim=*/reduce_dims);
} else {
  sum = at::sum(grad_out_, /*dim=*/reduce_dims);
}

Contributor


Sure, either one should work. I personally don't mind more verbose / longer code if it's more readable and easier to maintain.

@cpuhrsch
Contributor

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorch-bot bot pushed a commit that referenced this pull request Feb 24, 2025
Use float data type for Half sum in fallback implementation of batchnorm backward on CPU (#147353)

Fixes #147303.
Use float data type for Half sum in fallback implementation of batchnorm backward on CPU as the representation range of Half is small.

Pull Request resolved: #147353
Approved by: https://github.com/leslie-fang-intel, https://github.com/cpuhrsch
majing921201 pushed a commit to majing921201/pytorch that referenced this pull request Mar 4, 2025
Use float data type for Half sum in fallback implementation of batchnorm backward on CPU (pytorch#147353)

Fixes pytorch#147303.
Use float data type for Half sum in fallback implementation of batchnorm backward on CPU as the representation range of Half is small.

Pull Request resolved: pytorch#147353
Approved by: https://github.com/leslie-fang-intel, https://github.com/cpuhrsch
Labels
ciflow/inductor, ciflow/trunk, Merged, open source, release notes: nn
Projects
None yet
Development

Successfully merging this pull request may close these issues.

fp16 channels_last created Nan in batchnorm backward
5 participants