basic compile support for grouped_mm #153384
base: gh/bdhirsh/661/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153384
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 2 Unrelated Failures as of commit 4e78bbc with merge base daca611.
NEW FAILURE - The following job has failed:
BROKEN TRUNK - The following job failed but was also present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures.
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile:

```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```

[ghstack-poisoned]
grouped_mm is used in torchtitan; this adds just enough support in compile to allow inductor to lower it as a fallback kernel. I imagine that at some point in the future it may be valuable to get inductor to support templating grouped_mm, although this PR just provides basic support.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov ngimel eellison

[ghstack-poisoned]
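As a rough sketch of the usage this enables: the exact Python signature of `torch._grouped_mm` and the shapes/dtypes below are assumptions for illustration, and the test added in this PR targets SM90 GPUs.

```python
import torch

# Hedged sketch: compile a function that calls the grouped-mm op and let
# inductor lower it as a fallback kernel (this PR's change). The signature
# torch._grouped_mm(a, b, offs=...) and the shapes below are assumed.
@torch.compile
def grouped(a, b, offs):
    return torch._grouped_mm(a, b, offs=offs)

if torch.cuda.is_available():
    # 4 groups of 64 rows each, shared K=64, per-group weight matrices with N=32.
    a = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 64, 32, device="cuda", dtype=torch.bfloat16)
    offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)
    out = grouped(a, b, offs)  # expected shape: (256, 32)
```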
```
@@ -13741,6 +13742,50 @@ def forward(
        )
        torch._inductor.aot_compile(traced, inputs)

    @skipCUDAIf(not SM90OrLater, "Requires sm90")
    @requires_gpu()
    @config.patch(implicit_fallbacks=True)
```
I removed the lowering I added in inductor - I realized that implicit_fallbacks actually defaults to True, but was set to False in this test suite (thanks @ngimel for asking why I needed the lowering).
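For context, a minimal sketch of how that knob is toggled around a compiled call (the config name comes from the diff above; the function and shapes are illustrative):

```python
import torch
from torch._inductor import config

def f(x, y):
    return x @ y

# implicit_fallbacks lets inductor fall back to the eager kernel for ops it
# has no lowering for; the test suite patches it explicitly, as in the diff.
with config.patch(implicit_fallbacks=True):
    compiled = torch.compile(f)
    out = compiled(torch.randn(8, 8), torch.randn(8, 8))
```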
Do we have an explicit striding requirement? If so, we should add it - potentially as require_exact_strides, requires_contiguous, etc.
good point. maybe @ngimel would know? (does grouped_mm only support contiguous inputs?)
For grouped_mm we don't, other than that one of the dimensions should be contiguous, similar to regular mm.
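(As a tiny illustration of the layouts meant here - just a sketch of "one dimension contiguous", not grouped_mm's actual checks:)

```python
import torch

# Row-major operand: the last dimension is contiguous.
a = torch.randn(128, 64)        # strides (64, 1)
# Transposed view: the first dimension is contiguous instead.
b = torch.randn(64, 128).t()    # shape (128, 64), strides (1, 128)
print(a.stride(), b.stride())   # both are acceptable mm-style layouts
```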
@bdhirsh can you add `make_fallback(aten._grouped_mm, require_dense)`?
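A sketch of what that registration would look like. The real call would live in torch/_inductor/lowering.py, where `aten = torch.ops.aten` is already in scope; the imports below assume `make_fallback` and `require_dense` are reachable from `torch._inductor.lowering`:

```python
import torch
from torch._inductor.lowering import make_fallback, require_dense

aten = torch.ops.aten

# Register _grouped_mm as an inductor fallback kernel and ask inductor to
# pass it dense inputs, per the suggestion above.
make_fallback(aten._grouped_mm, require_dense)
```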
@eellison what would be the correct way of handling this for custom ops provided by third-party libraries?
For custom ops provided by third-party libraries, we are trying to match the eager strides by default. @zou3519 is going to turn that on (I forget whether he has yet).
Otherwise we have these tags:
pytorch/aten/src/ATen/native/tags.yaml (lines 45 to 67 in 236b08c):

```yaml
- tag: needs_exact_strides
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same strides as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: needs_contiguous_strides
  desc: |
    This tag indicates that the operator should be passed contiguous Tensors.
    Failure to do so will result in undefined behavior.
- tag: needs_fixed_stride_order
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same stride permutation as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: flexible_layout
  desc: |
    This tag indicates that the custom operator can accept inputs with varying
    strides/storage_offset and that when compiled, Inductor is allowed to change
    the strides/storage_offset of inputs to the custom operator.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
```
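As an illustration, this is roughly how a third-party library could attach one of these tags when defining its op (the op name, schema, and dispatch key below are purely hypothetical, and it assumes `torch.library.define` accepts a `tags` argument as in recent PyTorch):

```python
import torch

# Define a custom op and tag it so inductor preserves the stride permutation
# observed in eager. Everything named "mylib" here is illustrative.
torch.library.define(
    "mylib::fused_gemm",
    "(Tensor a, Tensor b) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)

@torch.library.impl("mylib::fused_gemm", "CompositeExplicitAutograd")
def fused_gemm(a, b):
    return a @ b
```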
Ah interesting, so there's no "require_dense" equivalent?
There's not; we could add one if someone wanted.
grouped_mm is used in torchtitan; this adds just enough support in compile to allow inductor to lower it as a fallback kernel. I imagine that at some point in the future it may be valuable to get inductor to support templating grouped_mm, although this PR just provides basic support. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ngimel @eellison
Stack from ghstack (oldest at bottom):