[Inductor] Expand dtype aware codegen for libdevice and tl.math ops by blaine-rister · Pull Request #140864 · pytorch/pytorch · GitHub

[Inductor] Expand dtype aware codegen for libdevice and tl.math ops #140864


Closed
blaine-rister wants to merge 63 commits from the brister/dtype_codegen branch

Conversation

@blaine-rister (Contributor) commented Nov 16, 2024

Feature

Previously, only the codegen for `torch.sqrt` was dtype aware. This PR updates most of the `libdevice`/`tl.math` ops to support dtype-aware codegen as well. This is often necessary to get correct code when `config.triton.codegen_upcast_to_fp32=False`, as most Triton math ops do not support float16/bfloat16.
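
For context, this is roughly how one would exercise that configuration (a minimal sketch, assuming a CUDA device; the flag name is taken from this PR):

```python
import torch
import torch._inductor.config as inductor_config

# Keep fp16/bf16 values in their native precision instead of upcasting
# everything to fp32 at load time.
inductor_config.triton.codegen_upcast_to_fp32 = False

@torch.compile
def f(x):
    # tanh and sqrt lower to libdevice/tl.math calls, exercising the
    # dtype-aware codegen path described above.
    return torch.tanh(x) * torch.sqrt(x.abs())

out = f(torch.randn(1024, device="cuda", dtype=torch.float16))
```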

This PR enables dtype-aware codegen via the `maybe_upcast_float32` decorator. This wraps `TritonOverrides` macros to upcast arguments to float32, and downcast the result back to the original dtype. The exception is for ops that return booleans, in which case we set `convert_output=False` and skip the output cast.
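
The decorator's effect can be sketched as follows; everything here (the `TritonVar` helper, the dtype strings) is illustrative rather than the actual `TritonOverrides` internals:

```python
from dataclasses import dataclass
from functools import wraps

# Hypothetical stand-in for Inductor's CSE variables: an expression
# string plus the dtype it carries through codegen.
@dataclass
class TritonVar:
    expr: str
    dtype: str  # e.g. "tl.float16", "tl.bfloat16", "tl.float32"

LOW_PRECISION = {"tl.float16", "tl.bfloat16"}

def maybe_upcast_float32(convert_output: bool = True):
    """Wrap a codegen macro: upcast fp16/bf16 arguments to fp32 before
    the math op, then downcast the result unless convert_output=False
    (used for ops that return booleans)."""
    def decorator(fn):
        @wraps(fn)
        def wrapper(*args: TritonVar) -> TritonVar:
            in_dtype = args[0].dtype
            exprs = [a.expr for a in args]
            if in_dtype in LOW_PRECISION:
                # Emit casts to fp32 around each argument expression.
                exprs = [f"{e}.to(tl.float32)" for e in exprs]
            result = fn(*exprs)
            if in_dtype in LOW_PRECISION and convert_output:
                # Downcast the fp32 result back to the original dtype.
                return TritonVar(f"({result}).to({in_dtype})", in_dtype)
            out_dtype = "tl.int1" if not convert_output else in_dtype
            return TritonVar(result, out_dtype)
        return wrapper
    return decorator

# Macros in the style of TritonOverrides:
@maybe_upcast_float32()
def sqrt(x: str) -> str:
    return f"libdevice.sqrt({x})"

@maybe_upcast_float32(convert_output=False)
def isinf(x: str) -> str:
    return f"libdevice.isinf({x})"

v = TritonVar("tmp0", "tl.float16")
print(sqrt(v).expr)   # (libdevice.sqrt(tmp0.to(tl.float32))).to(tl.float16)
print(isinf(v).expr)  # libdevice.isinf(tmp0.to(tl.float32))
```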

Test Plan

Added CI tests for all the new ops. The list of ops to test is automatically generated based on uses of the `maybe_upcast_float32` decorator, and stored in the new `OpDtypeSupport` class. In each new test, we search the generated code for upcasts/downcasts using a regex.
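
The regex check is in the spirit of the following sketch (the generated-code string here is illustrative, not real Inductor output):

```python
import re

# The generated Triton source should upcast fp16 inputs to fp32 and
# downcast the result back to fp16.
generated_code = "tmp1 = libdevice.sqrt(tmp0.to(tl.float32)).to(tl.float16)"

assert re.search(r"\.to\(tl\.float32\)", generated_code), "missing upcast"
assert re.search(r"\.to\(tl\.float16\)", generated_code), "missing downcast"
```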

Also added a unit test for `OpDtypeSupport` which checks that we have correct dtype info for ops that require upcasts.
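
A rough sketch of the kind of lookup such a class could expose (structure assumed, not the actual implementation):

```python
import torch

class OpDtypeSupport:
    # Dtypes each op's Triton lowering supports natively. In the real
    # code this is populated from maybe_upcast_float32 uses; the entries
    # here are hand-written purely for illustration.
    supported_dtypes = {
        "sqrt": {torch.float32, torch.float64},
        "isinf": {torch.float32, torch.float64},
    }

    @classmethod
    def needs_upcast(cls, op: str, dtype: torch.dtype) -> bool:
        return dtype not in cls.supported_dtypes.get(op, set())

assert OpDtypeSupport.needs_upcast("sqrt", torch.float16)
assert not OpDtypeSupport.needs_upcast("sqrt", torch.float32)
```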

This PR also moves some existing tests around, collecting all the dtype-aware codegen tests in one file.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

eellison and others added 9 commits November 6, 2024 16:27
pytorch-bot bot commented Nov 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/140864

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit 9c5d202 with merge base ed77901:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

blaine-rister and others added 10 commits November 15, 2024 18:50
Adds the remaining unimplemented ops, as well as an assertion failure if someone adds a new op without a dtype rule.

We test all unique pointwise operators that are registered as lowerings and have an OpInfo. There will be some follow-ups to make this work well with `codegen_upcast_to_fp32` set to both True and False.
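
A minimal sketch of that registration-time guard, with hypothetical names (the real registry lives in Inductor's lowering code):

```python
# Hypothetical dtype-rule registry: every pointwise op registered as a
# lowering must declare how its dtype behavior is handled.
DTYPE_RULES: dict[str, str] = {
    "sqrt": "upcast_to_fp32",
    "isinf": "returns_bool",
}

def register_pointwise_lowering(op_name: str) -> None:
    # Fail fast at registration time rather than producing wrong code later.
    assert op_name in DTYPE_RULES, (
        f"op '{op_name}' has no dtype rule; add one to DTYPE_RULES "
        "before registering it as a pointwise lowering"
    )

register_pointwise_lowering("sqrt")      # ok
# register_pointwise_lowering("erfinv")  # would raise AssertionError
```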


@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@blaine-rister your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Dec 6, 2024
Revert "[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)"

This reverts commit 80ca6dd.

Reverted #140864 on behalf of https://github.com/atalman due to failing internally.
pytorchmergebot added the Reverted and ci-no-td (Do not run TD on this PR) labels Dec 6, 2024
@blaine-rister (Contributor, Author)

@pytorchbot revert -m "Nondeterministic test is failing internally"

pytorch-bot bot commented Dec 6, 2024

❌ 🤖 pytorchbot command failed:

@pytorchbot revert: error: the following arguments are required: -c/--classification

usage: @pytorchbot revert -m MESSAGE -c
                          {nosignal,ignoredsignal,landrace,weird,ghfirst}

Try @pytorchbot --help for more info.

@blaine-rister (Contributor, Author)

@pytorchbot merge


@facebook-github-bot (Contributor)

@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot (Collaborator)

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@blaine-rister (Contributor, Author)

Canceling merge until I can confirm that internal tests are fixed.

pytorchmergebot added a commit that referenced this pull request Dec 7, 2024
Revert "[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)"

This reverts commit 80ca6dd.

Reverted #140864 on behalf of https://github.com/atalman due to failing internally.
@blaine-rister (Contributor, Author)

@pytorchbot merge -i

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-focal-py3.13-clang10 / build, pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot (Collaborator)

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see pytorch-bot wiki.

@blaine-rister (Contributor, Author)

@pytocrhbot merge -i

@blaine-rister (Contributor, Author)

@pytorchbot merge -i

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged while ignoring the following 2 checks: pull / linux-focal-py3.13-clang10 / build, pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

pytorch-bot bot pushed a commit that referenced this pull request Dec 9, 2024
[Inductor] Expand dtype aware codegen for libdevice and tl.math ops (#140864)

Pull Request resolved: #140864
Approved by: https://github.com/eellison, https://github.com/arui-meta

Co-authored-by: eellison <elias.ellison@gmail.com>
AmdSampsa pushed a commit to AmdSampsa/pytorch that referenced this pull request Dec 9, 2024
github-actions bot deleted the brister/dtype_codegen branch January 8, 2025 02:04