[Inductor] Expand dtype aware codegen for libdevice and tl.math ops #140864
Conversation
Adds the remaining unimplemented ops, as well as an assertion failure if someone adds a new op without a dtype rule. We test all unique pointwise operators registered as lowerings which have an OpInfo. There will be some follow-ups for this to work well with `codegen_upcast_to_fp32` as both True and False.
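For illustration, here is a minimal sketch of the kind of guard described above; `POINTWISE_OPS` and `DTYPE_RULES` are made-up stand-ins for Inductor's real registries, not the actual implementation.

```python
# Hypothetical sketch: fail loudly if a pointwise op is registered without a
# corresponding dtype rule. The registries below are illustrative only.
POINTWISE_OPS = {"cos", "sin", "rsqrt", "isnan"}
DTYPE_RULES = {"cos": "upcast_to_fp32", "sin": "upcast_to_fp32", "rsqrt": "upcast_to_fp32"}

def assert_all_ops_have_dtype_rules() -> None:
    # Any op without a rule should surface as an assertion failure in CI.
    missing = sorted(op for op in POINTWISE_OPS if op not in DTYPE_RULES)
    assert not missing, f"ops missing dtype rules: {missing}"
```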
@pytorchbot successfully started a revert job. Check the current status here.
@blaine-rister your PR has been successfully reverted.
…th ops (#140864)" This reverts commit 80ca6dd. Reverted #140864 on behalf of https://github.com/atalman due to failing internally ([comment](#140864 (comment)))
@pytorchbot revert -m "Nondeterministic test is failing internally"
❌ 🤖 pytorchbot command failed.
@pytorchbot merge
@blaine-rister has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Canceling merge until I can confirm that internal tests are fixed.
…th ops (#140864)" This reverts commit 80ca6dd. Reverted #140864 on behalf of https://github.com/atalman due to failing internally ([comment](#140864 (comment)))
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: pull / linux-focal-py3.13-clang10 / build, pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
@pytocrhbot merge -i
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 2 checks: pull / linux-focal-py3.13-clang10 / build, pull / linux-focal-py3_9-clang9-xla / test (xla, 1, 1, lf.linux.12xlarge). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Pull Request resolved: #140864. Approved by: https://github.com/eellison, https://github.com/arui-meta. Co-authored-by: eellison <elias.ellison@gmail.com>
# Feature

Previously, only the codegen for `torch.sqrt` was dtype aware. This PR updates most of the `libdevice`/`tl.math` ops to support dtype-aware codegen as well. This is often necessary to get correct code when `config.triton.codegen_upcast_to_fp32=False`, as most Triton math ops do not support float16/bfloat16.

This PR enables dtype-aware codegen via the `maybe_upcast_float32` decorator. This wraps `TritonOverrides` macros to upcast arguments to float32 and downcast the result back to the original dtype. The exception is ops that return booleans, in which case we set `convert_output=False` and skip the output cast.
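The decorator's implementation is not reproduced in this thread, so the following is only a minimal sketch of the idea, assuming the overrides emit Triton source as strings; the helper names and exact cast syntax are illustrative, not Inductor's actual code.

```python
import functools

def maybe_upcast_float32(convert_output: bool = True):
    """Sketch of an upcast/downcast wrapper for string-emitting Triton overrides.

    Low-precision inputs are cast to float32 before the wrapped op, and the
    result is cast back to the original dtype unless the op returns a bool
    (convert_output=False).
    """
    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(*args, dtype="tl.float32"):
            low_precision = dtype in ("tl.float16", "tl.bfloat16")
            if low_precision:
                # Upcast every operand expression to float32.
                args = tuple(f"{a}.to(tl.float32)" for a in args)
            expr = fn(*args)
            if low_precision and convert_output:
                # Downcast the emitted expression back to the original dtype.
                expr = f"({expr}).to({dtype})"
            return expr
        return wrapped
    return decorator

@maybe_upcast_float32()
def cos(x):
    return f"libdevice.cos({x})"

@maybe_upcast_float32(convert_output=False)
def isnan(x):
    return f"libdevice.isnan({x}).to(tl.int1)"

# cos("tmp0", dtype="tl.bfloat16")
# -> "(libdevice.cos(tmp0.to(tl.float32))).to(tl.bfloat16)"
```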
# Test Plan

Added CI tests for all the new ops. The list of ops to test is automatically generated based on uses of the `maybe_upcast_float32` decorator and stored in the new `OpDtypeSupport` class. In each new test, we search the generated code for upcasts/downcasts using a regex.

Also added a unit test for `OpDtypeSupport` which checks that we have correct dtype info for ops that require upcasts.

This PR also moves some existing tests around, to collect all the dtype-aware codegen tests in one file.
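As a rough illustration of that regex check (assuming a CUDA device and that `run_and_get_triton_code` from Inductor's utilities is available; the exact patterns and test body below are assumptions rather than the PR's real tests):

```python
import re
import torch
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

@config.patch({"triton.codegen_upcast_to_fp32": False})
def check_cos_upcast_downcast():
    # Compile a simple float16 pointwise op and inspect the generated Triton kernel.
    compiled = torch.compile(torch.cos)
    x = torch.randn(128, device="cuda", dtype=torch.float16)
    code = run_and_get_triton_code(compiled, x)

    # The argument should be upcast to float32 and the result downcast to float16.
    assert re.search(r"\.to\(tl\.float32\)", code), "missing upcast to float32"
    assert re.search(r"\.to\(tl\.float16\)", code), "missing downcast to float16"

# check_cos_upcast_downcast()  # requires a CUDA build of PyTorch
```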
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov