8000 [inductor] use smaller RBLOCK for expensive reduction kernels by shunting314 · Pull Request #126477 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] use smaller RBLOCK for expensive reduction kernels #126477

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

shunting314
Copy link
Contributor
@shunting314 shunting314 commented May 16, 2024

Stack from ghstack (oldest at bottom):

Triton sometimes uses less registers for more expensive kernel which results in worse perf ( #126463 ). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially need many registers.

Will run perf test..

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang

Copy link
pytorch-bot bot commented May 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/126477

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit 32e418d with merge base 53f73cd (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

…els"

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
…els"


Triton sometimes uses less registers for more expensive kernel which results in worse perf ( #126463 ). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially need many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
shunting314 added a commit that referenced this pull request May 17, 2024
Copy link
Contributor
@eellison eellison left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this motivation to do live variable analysis.. going to wait to see how perf results look.

@@ -1480,12 +1485,32 @@ def _reduction_configs(
assert len(size_hints) == 2
rnumel = size_hints[-1]

MAX_RBLOCK = 2048
if (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

were there any register spills on old heuristics here ?

8000

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no. I checked CompiledKernel.n_spills, it's 0.

…els"


Triton sometimes uses less registers for more expensive kernel which results in worse perf ( #126463 ). This may make inductor end up with a sub-optimal config. Use a smaller max RBLOCK if the reduction potentially need many registers.

Will run perf test..


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
@shunting314
Copy link
Contributor Author

Perf test on this PR is roughly neutral.

I'll land this one and debug the perf/compilation time regress for the PR tuning partitioner.

@shunting314
Copy link
C 8000 ontributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 22, 2024
@shunting314 shunting314 added the topic: not user facing topic category label May 22, 2024
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the gh/shunting314/146/head branch June 22, 2024 05:32
desai0007 pushed a commit to desai0007/test-repo-pytorch that referenced this pull request Feb 26, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants
0