[Inductor] Record Triton’s Base32 Cache Key in `.best_config` for Debugging #147019

Closed
fulvius31 wants to merge 7 commits

Conversation

@fulvius31 (Contributor) commented Feb 12, 2025

This PR modifies TorchInductor’s autotuning flow so that each `best_config` JSON file also includes the Triton “base32” (or base64) cache key.

Motivation

Debugging & Analysis: with this change, we can quickly identify which compiled binary and IRs belong to a given best config.
The impact is minimal, since it is only an extra field in .best_config, and it can help with advanced performance tuning or kernel-level debugging.

Also, since Triton already stores the cubin/hsaco in its cache, developers and researchers can avoid setting store_cubin = True: they can retrieve the cubin/hsaco from the Triton cache, and with the code provided in this PR they can easily match each best_config with the Triton cache directory of the "best" kernel.
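
For illustration, here is a minimal sketch (not part of this PR's code) of how the recorded key can be used to locate the matching Triton cache directory. The field name triton_cache_dir and the .best_config path below are assumptions for the example; Triton's default cache root can be overridden with TRITON_CACHE_DIR:

    import json
    import os

    # Placeholder path to a .best_config written by Inductor's autotuner.
    best_config_path = "/tmp/torchinductor_user/abc.best_config"

    with open(best_config_path) as f:
        best_config = json.load(f)

    # Field recorded by this PR: the base32/base64-style Triton cache key
    # of the winning kernel (assumed field name, per the review discussion).
    cache_key = best_config["triton_cache_dir"]

    # Triton's default cache root, unless TRITON_CACHE_DIR overrides it.
    triton_root = os.environ.get(
        "TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))

    kernel_dir = os.path.join(triton_root, cache_key)
    print(kernel_dir, "exists:", os.path.isdir(kernel_dir))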

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @davidberard98

pytorch-bot bot commented Feb 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147019

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 51edce0 with merge base 6c089f5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@fulvius31 (Contributor Author)

@pytorchbot label "ciflow/inductor"

pytorch-bot bot commented Feb 12, 2025

To add these label(s) (ciflow/inductor) to the PR, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@fulvius31 (Contributor Author)

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing (topic category) label Feb 12, 2025
@janeyx99 janeyx99 requested a review from eellison February 13, 2025 18:43
@janeyx99 janeyx99 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 13, 2025
@davidberard98 davidberard98 self-requested a review February 17, 2025 23:14
@davidberard98 (Contributor)

@fulvius31 would it be possible to add a test for this?

@fulvius31 (Contributor Author)

> @fulvius31 would it be possible to add a test for this?

Hey @davidberard98, do you mean adding a test to verify that the new field is properly recorded in the best_config JSON file?

@davidberard98 (Contributor) left a comment

@fulvius31

  1. I tested this with a newer version of Triton (a commit from Jan 28), and for the commit hash recorded in .best_config, I can't seem to find a corresponding Triton hash. Should I expect the corresponding Triton hash to be somewhere in the /tmp/torchinductor_${USER} directory, or is it somewhere else? (And if it is in /tmp/torchinductor..., then I think this might be broken for newer versions of Triton.)
  2. re: adding a test - yes, if it's not too hard to add a test, it would be great if we could check (a) whether the .best_config exists and contains the expected cache key, and (b) whether the cache key corresponds to a real inductor file somewhere.

@fulvius31 (Contributor Author)

> @fulvius31
>
> 1. I tested this with a newer version of Triton (a commit from Jan 28), and for the commit hash recorded in .best_config, I can't seem to find a corresponding Triton hash. Should I expect the corresponding Triton hash to be somewhere in the /tmp/torchinductor_${USER} directory, or is it somewhere else? (And if it is in /tmp/torchinductor..., then I think this might be broken for newer versions of Triton.)

The corresponding Triton hash should be under ~/.triton/cache/, which is the default TRITON_CACHE_DIR.

> 2. re: adding a test - yes, if it's not too hard to add a test, it would be great if we could check (a) whether the .best_config exists and contains the expected cache key, and (b) whether the cache key corresponds to a real inductor file somewhere.

I can absolutely work on it. Can I open a new PR for that, or do you want me to include it in this one?

@fulvius31 (Contributor Author)

@davidberard98 were you able to test it?

@fulvius31 fulvius31 requested a review from a team as a code owner February 25, 2025 19:35
@fulvius31 (Contributor Author)

I've added a test for this PR as suggested, @davidberard98
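
(For readers following along, a rough sketch of what such a test can check; the field name, cache locations, and CUDA availability are assumptions based on the discussion above, not the exact test added in this PR:)

    import glob
    import json
    import os

    import torch

    # Max-autotune exercises the autotuning path that writes .best_config files.
    @torch.compile(mode="max-autotune")
    def f(a, b):
        return (a @ b).relu()

    f(torch.randn(256, 256, device="cuda"),
      torch.randn(256, 256, device="cuda"))

    cache_root = os.environ.get(
        "TORCHINDUCTOR_CACHE_DIR",
        f"/tmp/torchinductor_{os.environ.get('USER', 'unknown')}")
    triton_root = os.environ.get(
        "TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))

    for path in glob.glob(f"{cache_root}/**/*.best_config", recursive=True):
        with open(path) as fp:
            cfg = json.load(fp)
        key = cfg.get("triton_cache_dir")  # field added by this PR (assumed name)
        # (a) the field exists; (b) it names a real Triton cache directory.
        assert key is not None, path
        assert os.path.isdir(os.path.join(triton_root, key)), key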

@eellison (Contributor) left a comment

letting @davidberard98 take this one

@davidberard98 (Contributor)

@fulvius31
first of all, thanks for your patience - we've been busy prepping for the triton 3.3 / pytorch 2.7 release

Here's what my tmp directory looks like when I run your test
https://gist.github.com/davidberard98/8adbd2e282b322f7c8a9ddc11912f631

(I put a breakpoint at the end of your test to look at the directory before the test gets removed)

And my ~/.triton/cache is empty. (As I expect it to be - I think Inductor puts triton kernels in the /tmp/torchinductor_$USER directory)
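
(Which cache root is actually in effect varies by setup; a quick way to check on any machine, as a debugging aid only - nothing here is specific to this PR:)

    import glob
    import os

    # Triton falls back to ~/.triton/cache when this env var is unset;
    # Inductor may point it elsewhere depending on version and install.
    print("TRITON_CACHE_DIR =", os.environ.get("TRITON_CACHE_DIR", "<unset>"))

    # List any .best_config files under the default Inductor cache dir.
    inductor_cache = f"/tmp/torchinductor_{os.environ.get('USER', 'unknown')}"
    for path in glob.glob(f"{inductor_cache}/**/*.best_config", recursive=True):
        print(path)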

@fulvius31 (Contributor Author)

> @fulvius31
> first of all, thanks for your patience - we've been busy prepping for the triton 3.3 / pytorch 2.7 release
>
> Here's what my tmp directory looks like when I run your test
> https://gist.github.com/davidberard98/8adbd2e282b322f7c8a9ddc11912f631
>
> (I put a breakpoint at the end of your test to look at the directory before the test gets removed)
>
> And my ~/.triton/cache is empty. (As I expect it to be - I think Inductor puts triton kernels in the /tmp/torchinductor_$USER directory)

No worries at all!

Yes, the default Triton cache directory depends on whether you installed Triton from source or with torch. Btw, have you found the Triton cache directory in the new field I added to best_config?

@davidberard98 (Contributor)

@fulvius31

> Btw, have you found the Triton cache directory in the new field I added to best_config?

No - all the triton cache artifacts are shown in the gist I pasted

If you delete your .triton/cache and re-run, do you still find the cache dir matching the .best_config file?

@fulvius31 (Contributor Author) commented Feb 28, 2025

> @fulvius31
>
> > Btw, have you found the Triton cache directory in the new field I added to best_config?
>
> No - all the triton cache artifacts are shown in the gist I pasted

Can you paste the content of best_config produced with the edits I made? @davidberard98

> If you delete your .triton/cache and re-run, do you still find the cache dir matching the .best_config file?

Yes, I do. Whenever the best_config is found, there is always a corresponding Triton cache dir.

@fulvius31 (Contributor Author)

@davidberard98 Based on your gist, in 6bb45a2ecc8066b0a2b09b298dcfb30d04cfc6576f5d1361a82266feb7bf4653.best_config the triton_cache_dir value should be either s1SJRsxpmxFBKMcqqwNibsuhl2wvju606aNZizbyINk or ThWw5E-MLq3qaRXcbyUkc9UNa0PwhtI8B0OlYh9pKEU. Does the test I've included work for you?

@davidberard98 (Contributor) left a comment

ah, I see - thanks for pointing me to the right place. Approving for now, but I had a follow-up question: where does the name of the .best_config file come from? (i.e. [best_config_hash].best_config - where does [best_config_hash] come from?)

@fulvius31 (Contributor Author)

> ah, I see - thanks for pointing me to the right place. Approving for now, but I had a follow-up question: where does the name of the .best_config file come from? (i.e. [best_config_hash].best_config - where does [best_config_hash] come from?)

Thank you for the review, merge it when you can!

So, as far as I understand, best_config_hash is derived from _prepare_key(filename), which is basically another hash of the generated .py file’s own SHA-256-derived name (the Triton code; in your case cxxckjomvbgwdkydoxzaluzbqgvq4icxezf3rabmvf4deu4wpwwz.py) plus a "salt", config.cache_key_tag.
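
Roughly, the idea is the following (an illustrative sketch, not the exact Inductor code; the helper name and hash construction are assumptions):

    import hashlib
    import os

    def prepare_key(filename: str, salt: str = "") -> str:
        # The generated kernel file is already named after a hash of its
        # source (e.g. cxxckjomv....py); the .best_config name re-hashes
        # that name together with a salt (config.cache_key_tag).
        stem = os.path.basename(filename).removesuffix(".py")
        return hashlib.sha256((stem + salt).encode()).hexdigest()

    # e.g. prepare_key("cxxckjomvbgwdkydoxzaluzbqgvq4icxezf3rabmvf4deu4wpwwz.py")
    # would yield something like the [best_config_hash] part of
    # [best_config_hash].best_config.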

@davidberard98 (Contributor)

@pytorchbot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 3, 2025
@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: Comment with id 2697308151 not found

Details for Dev Infra team: raised by workflow job.

@fulvius31 (Contributor Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@fulvius31 fulvius31 deleted the inductor_cache branch March 4, 2025 13:20
@clee2000 (Contributor) commented Mar 4, 2025

@pytorchbot revert -m "broke inductor test inductor/test_max_autotune.py::TestMaxAutotune::test_cat_max_autotune_extern GH job link HUD commit link on inductor workflow and rocm workflow" -c nosignal

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

@fulvius31 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Mar 4, 2025
… for Debugging (#147019)"

This reverts commit e3e45d9.

Reverted #147019 on behalf of https://github.com/clee2000 due to broke inductor test inductor/test_max_autotune.py::TestMaxAutotune::test_cat_max_autotune_extern [GH job link](https://github.com/pytorch/pytorch/actions/runs/13653495421/job/38171259603) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/e3e45d90d8578083da8b51a3b1d911e9a4523e5b) on inductor workflow and rocm workflow ([comment](#147019 (comment)))
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Mar 4, 2025
pytorchmergebot pushed a commit to min-jean-cho/pytorch that referenced this pull request Mar 5, 2025
…ugging (pytorch#147019)

Pull Request resolved: pytorch#147019
Approved by: https://github.com/davidberard98
@fulvius31 fulvius31 restored the inductor_cache branch March 5, 2025 09:47
@fulvius31 (Contributor Author)

Hi @clee2000, should I re-open another PR to push the commit with the fix?

@fulvius31 (Contributor Author)

I was able to reproduce the error using ROCm: https://termbin.com/vr1i

And I was able to fix it: fulvius31@0a63a54

The test passes: https://termbin.com/pt06

@clee2000 @davidberard98

pytorchmergebot pushed a commit that referenced this pull request Apr 24, 2025
…ging (#148981)

This is a follow-up PR of the reverted one #147019 :

Pull Request resolved: #148981
Approved by: https://github.com/davidberard98
wangkuiyi pushed a commit to wangkuiyi/pytorch that referenced this pull request Apr 25, 2025
…ging (pytorch#148981)
@ZainRizvi (Contributor)

@pytorchbot revert -c ghfirst -m "Sorry but this is breaking internally. @davidberard98 can you please help get these changes validated? Details in D73628297. To validate the fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts"

This failure is extra weird since it seems to be failing on a test that's on GitHub.

Failure is: AttributeError: 'function' object has no attribute 'cache_hash'

And the stack is pointing to this line of code that's also on GitHub, right here:

_inductor/runtime/triton_heuristics.py", line 877, in autotune_to_one_config
    triton_cache_hash=launcher.cache_hash,
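
A defensive rewrite of that call site would sidestep the crash when launcher is a plain function; this is only a sketch of the failure mode, not the fix that eventually landed:

    # torch/_inductor/runtime/triton_heuristics.py, autotune_to_one_config (sketch)
    # Guard the attribute access instead of assuming every launcher object
    # carries a cache hash; a plain-function launcher otherwise raises the
    # AttributeError seen in the internal failure above.
    triton_cache_hash = getattr(launcher, "cache_hash", None)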

@pytorchmergebot (Collaborator)

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator)

Reverting PR 147019 failed

Reason: Command git -C /home/runner/work/pytorch/pytorch revert --no-edit e3e45d90d8578083da8b51a3b1d911e9a4523e5b returned non-zero exit code 1

Auto-merging test/run_test.py
CONFLICT (content): Merge conflict in test/run_test.py
Auto-merging torch/_inductor/runtime/autotune_cache.py
Auto-merging torch/_inductor/runtime/triton_heuristics.py
error: could not revert e3e45d90d85... [Inductor] Record Triton’s Base32 Cache Key in `.best_config` for Debugging (#147019)
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git revert --continue".
hint: You can instead skip this commit with "git revert --skip".
hint: To abort and get back to the state before "git revert",
hint: run "git revert --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

@ZainRizvi (Contributor)

Oh, I clicked the wrong link in the diff. The follow-up was the one to revert...

rec pushed a commit to rec/pytorch that referenced this pull request Apr 25, 2025
…ging (pytorch#148981)
Labels
ci-no-td (Do not run TD on this PR), ciflow/inductor, ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: inductor, open source, Reverted, topic: not user facing (topic category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)