basic compile support for grouped_mm by bdhirsh · Pull Request #153384 · pytorch/pytorch · GitHub

basic compile support for grouped_mm #153384


Open: wants to merge 3 commits into base branch gh/bdhirsh/661/base

Conversation

@bdhirsh (Contributor) commented May 12, 2025

grouped_mm is used in torchtitan; this PR adds just enough support in compile to allow inductor to lower it as a fallback kernel. At some point in the future it may be valuable to have inductor support templating grouped_mm, but this PR only provides basic support. cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @ngimel @eellison

Stack from ghstack (oldest at bottom):

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov
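For context, here is a minimal sketch of what a grouped mm computes in eager mode. The `torch._grouped_mm` call signature and the 2D-activation / 3D-weight / offsets layout below are assumptions based on the MoE-style usage in torchtitan, not details taken from this PR, and the eager kernel is gated on newer hardware (the test added here is skipped below sm90):

```python
import torch

# Grouped matmul: rows of `x` are split into groups by the cumulative offsets
# in `offs`, and group g is multiplied by its own weight matrix w[g].
# Shapes and dtypes here are illustrative assumptions.
x = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)     # (total_tokens, k)
w = torch.randn(4, 64, 128, device="cuda", dtype=torch.bfloat16)  # (num_groups, k, n)
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)

out = torch._grouped_mm(x, w, offs=offs)                           # (total_tokens, n)

# Reference computed with a plain Python loop over groups:
starts = [0] + offs.tolist()[:-1]
ref = torch.cat([x[s:e] @ w[g] for g, (s, e) in enumerate(zip(starts, offs.tolist()))])
torch.testing.assert_close(out, ref, rtol=2e-2, atol=2e-2)
```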

pytorch-bot (bot) commented May 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153384

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit 4e78bbc with merge base daca611:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: bfb03e9
Pull Request resolved: #153384
@bdhirsh bdhirsh requested review from eellison and ngimel May 12, 2025 17:35
bdhirsh added a commit to pytorch/torchtitan that referenced this pull request May 12, 2025
This PR + pytorch/pytorch#153384 is enough to get torchtitan running for me with llama4 and compile
```
CONFIG_FILE="./torchtitan/experiments/llama4/train_configs/debug_model.toml" ./run_train.sh --training.compile
```
@bdhirsh bdhirsh added the release notes: python_frontend python frontend release notes category label May 12, 2025
bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: fe4107e
Pull Request resolved: #153384
bdhirsh added a commit that referenced this pull request May 12, 2025
ghstack-source-id: 3c696c5
Pull Request resolved: #153384
The review thread below is anchored on this diff context (decorators on the newly added test):

```
@@ -13741,6 +13742,50 @@ def forward(
    )
    torch._inductor.aot_compile(traced, inputs)

@skipCUDAIf(not SM90OrLater, "Requires sm90")
@requires_gpu()
@config.patch(implicit_fallbacks=True)
```
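Roughly, a test guarded by these decorators checks that the compiled op matches eager, with implicit fallbacks enabled so inductor is allowed to emit `_grouped_mm` as a fallback kernel. The sketch below is an illustration under the same assumed `torch._grouped_mm` signature as above, not the PR's actual test body:

```python
import torch
from torch._inductor import config as inductor_config

def fn(x, w, offs):
    return torch._grouped_mm(x, w, offs=offs)

x = torch.randn(256, 64, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4, 64, 128, device="cuda", dtype=torch.bfloat16)
offs = torch.tensor([64, 128, 192, 256], device="cuda", dtype=torch.int32)

eager = fn(x, w, offs)
# implicit_fallbacks lets inductor emit ops without a registered lowering as
# fallback (extern) kernels instead of erroring.
with inductor_config.patch(implicit_fallbacks=True):
    compiled_out = torch.compile(fn)(x, w, offs)
torch.testing.assert_close(eager, compiled_out)
```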
@bdhirsh (Contributor, Author) commented May 12, 2025:

I removed the lowering I added in inductor - I realized that implicit_fallbacks actually defaults to True, but was set to False in this test suite (thanks @ngimel for asking why I needed the lowering).

Contributor commented:

Do we have an explicit striding requirement? If so, we should add it - potentially as require_exact_strides, requires_contiguous, etc.

@bdhirsh (Contributor, Author) commented:

Good point. Maybe @ngimel would know? (Does grouped_mm only support contiguous inputs?)

Collaborator commented:

For grouped_mm we don't, other than that one of the dimensions should be contiguous, similar to regular mm.
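To illustrate with regular mm (a small sketch, not code from this thread): an operand can be row-major or a transposed, column-major view, as long as one of its two dimensions has stride 1, so no copy is required. Inputs where neither dimension is contiguous still work in eager but may be copied to a dense layout internally.

```python
import torch

a = torch.randn(128, 64)      # row-major: last dim has stride 1
b = torch.randn(32, 64).t()   # transposed view (64, 32): first dim has stride 1
out = a @ b                   # mm accepts either layout without forcing a copy

c = torch.randn(32, 64)[:, ::2]  # (32, 32) view where neither dim has stride 1
out2 = a[:, :32] @ c             # still valid, but may materialize a dense copy internally
```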

Contributor commented:

@bdhirsh can you add `make_fallback(aten._grouped_mm, require_dense)`?
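As a sketch, the suggested registration would live next to the other fallbacks in torch/_inductor/lowering.py, where `make_fallback`, `require_dense`, and `aten` are already in scope (this illustrates the suggestion; the exact constraint that landed may differ):

```python
# Register aten._grouped_mm as an extern/fallback kernel rather than giving it
# an inductor lowering; require_dense asks inductor to hand the kernel inputs
# with a dense layout (contiguous memory, possibly with permuted dimensions).
make_fallback(aten._grouped_mm, require_dense)
```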

Collaborator commented:

@eellison what would be the correct way of handling this for custom ops provided by third-party libraries?

Contributor commented:

For custom ops provided by third-party libraries, we are trying to match the eager strides by default. @zou3519 is going to turn that on (i forget if he did or not yet).

Otherwise we have these tags:

```yaml
- tag: needs_exact_strides
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same strides as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: needs_contiguous_strides
  desc: |
    This tag indicates that the operator should be passed contiguous Tensors.
    Failure to do so will result in undefined behavior.
- tag: needs_fixed_stride_order
  desc: |
    This tag indicates that the operator should be passed Tensors following
    the same stride permutation as observed in eager when compiled in inductor.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
- tag: flexible_layout
  desc: |
    This tag indicates that the custom operator can accept inputs with varying
    strides/storage_offset and that when compiled, Inductor is allowed to change
    the strides/storage_offset of inputs to the custom operator.
    Only one of {needs_exact_strides, needs_contiguous_strides, needs_fixed_stride_order, flexible_layout}
    can apply; if multiple are assigned then we assume the most restrictive one.
```
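For a third-party library, a tag can be attached when the custom op's schema is defined. A minimal sketch with torch.library follows; the op name, schema, and implementation are made up for illustration, and it assumes a recent PyTorch where `torch.library.register_fake` and op tags are available:

```python
import torch

# Hypothetical third-party op, tagged so inductor preserves the eager stride
# order of its inputs when the op is hit under torch.compile.
torch.library.define(
    "mylib::fused_thing",
    "(Tensor x) -> Tensor",
    tags=(torch.Tag.needs_fixed_stride_order,),
)

@torch.library.impl("mylib::fused_thing", "cpu")
def fused_thing_cpu(x):
    # Stand-in implementation; a real library would dispatch to its own kernel.
    return x * 2

@torch.library.register_fake("mylib::fused_thing")
def fused_thing_fake(x):
    # Shape/dtype propagation for tracing under torch.compile.
    return torch.empty_like(x)

out = torch.compile(lambda t: torch.ops.mylib.fused_thing(t))(torch.randn(8, 8))
```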

Collaborator commented:

Ah interesting, so there's no "require_dense" equivalent?

Contributor commented:

There's not; we could add one if someone wanted.

@albanD albanD removed their request for review May 12, 2025 18:58