Recheck autotune cache on static cuda launcher load #153565
base: gh/jamesjwu/152/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/153565
Note: Links to docs will display an error until the docs builds have been completed. ❌ 2 New Failures as of commit 81ebd18 with merge base a2e2f90. NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
What's up with the gloo change?
    autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])

if disabled:
    autotune_cache_info["autotune_cache_state"] = "force_disabled"
This is just moved code, but in the case of len(configs) == 1 AND disabled, is it intentional to overwrite the autotune_cache_state?
Yes, force disable should override all the other logging
@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


def check_autotune_cache(
    configs, filename, inductor_meta
Can we annotate these?
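A possible shape for the annotations, as a sketch only; Config and AutotuneCache here are stand-ins for the real Inductor classes, and the return type is an assumption based on how the call site in this diff unpacks the result:

```python
from typing import Any, Optional

# Hypothetical annotation sketch, not the actual Inductor signature.
def check_autotune_cache(
    configs: list["Config"],
    filename: Optional[str],
    inductor_meta: dict[str, Any],
) -> tuple[list["Config"], Optional["AutotuneCache"], dict[str, Any]]:
    ...
```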
@@ -298,6 +347,39 @@ def is_statically_launchable(self):
            isinstance(x, StaticTritonCompileResult) for x in self.compile_results
        )

    def recheck_autotune_cache(self, reload_kernel_from_src) -> None:
Annotate reload_kernel_from_src?
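For instance, something along these lines (a sketch; the class is a stub just to host the annotated method, and the callable's return type is an assumption):

```python
from typing import Callable

class CachingAutotuner:  # stub shown only to carry the annotated signature
    def recheck_autotune_cache(
        self, reload_kernel_from_src: Callable[[], "CachingAutotuner"]
    ) -> None:
        ...
```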
+1 on the gloo question; I also have a couple of questions.
        )
        self.autotune_cache_info = autotune_cache_info
        # I.e. there was an autotune cache hit
        if len(cached_configs) == 1 and len(configs) > 1:
Why do we need to check len(configs) > 1? We'd still coordinate_descent_tune with a single config, right?
If len(configs) == 1, we don't need to prune the list of configs or compile results, so there's no need to loop through the list. If coordesc tuning is on, then we'll start coordesc tuning immediately.
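Roughly, the pruning being described looks like the sketch below; prune_to_cached_config and config_key are illustrative placeholders (config_key standing in for triton_config_to_hashable), not the actual code:

```python
# Sketch: on an autotune cache hit (exactly one cached config) with several
# precompiled configs, keep only the compile result matching the cached
# config; with a single config there is nothing to prune.
def prune_to_cached_config(compile_results, configs, cached_configs, config_key):
    if len(cached_configs) == 1 and len(configs) > 1:
        best_key = config_key(cached_configs[0])
        compile_results = [
            r for r in compile_results if config_key(r.config) == best_key
        ]
        configs = [r.config for r in compile_results]
    return compile_results, configs
```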
        if best_config.found_by_coordesc:
            with dynamo_timed("CachingAutotuner.slow_precompile_config"):
                if self.fn.fn is None:
                    self.fn = reload_kernel_from_src().fn
Is there a reason we need to rely on best_config.found_by_coordesc? Should this be an assert, or should we always be reloading?
I think this can be an assert, yes: if best_config isn't in the list of compiled configs, it should always be because of coordesc.
That said, I was a little scared that I might have missed a case where it's possible for best_config not to be in our list... and it doesn't seem particularly helpful to crash in that case? We should just re-autotune and pretend it was a cache miss, IMO.
        for compile_result in self.compile_results:
            if triton_config_to_hashable(compile_result.config) == best_config_hash:
Just so I am following: if we were coordinate descent tuning and happened to have the best config from the start, then these would be equal?
Correct: the only case where it falls out of this for loop is if coordesc tuning finds a best_config that wasn't one of the precompiled options. In that case, we haven't saved anything in the cache and need to recompile.
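In other words, the selection plus its fallback behaves roughly like this sketch; pick_result, config_key, and reautotune are illustrative names, not the real methods:

```python
# Sketch: find the compile result whose config hashes to the cached best
# config; if coordesc produced a config that was never precompiled, nothing
# matches, so treat it as a cache miss and re-autotune/recompile.
def pick_result(compile_results, best_config_hash, config_key, reautotune):
    for compile_result in compile_results:
        if config_key(compile_result.config) == best_config_hash:
            return compile_result
    return reautotune()
```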
Fixed gloo
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test. Details for Dev Infra team: raised by workflow job.
""" | ||
autotune_cache = None | ||
autotune_cache_info = {} | ||
disabled = inductor_meta.get("force_disable_caches", False) |
So if we only set the remote autotune cache to disabled, we won't honor that in this case?
Stack from ghstack (oldest at bottom):
When loading statically launchable triton kernels from FxGraphCache, since we don't instantiate a CachingAutotuner like we do normally, we need to recheck the autotune cache based on the existing compile results. If we get a hit, we take the compile result whose config matches the best config.
Sometimes, the best config will have come from coordinate descent tuning. In this case, FxGraphCache today does not cache the resulting triton kernel, with or without the static cuda launcher. This is because coordinate descent tuning happens at runtime, so the best config may not be one of the precompiled configs.
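At a high level, the recheck on FxGraphCache load behaves like the following sketch; the names and structure are simplified assumptions rather than the exact implementation:

```python
# Sketch of the recheck flow when a statically launchable kernel is loaded
# from FxGraphCache: consult the autotune cache, and on a hit narrow the
# precompiled results down to the one matching the cached best config.
def recheck_on_load(kernel, check_autotune_cache, reautotune):
    configs = [r.config for r in kernel.compile_results]
    cached_configs, autotune_cache, info = check_autotune_cache(
        configs, kernel.filename, kernel.inductor_meta
    )
    kernel.autotune_cache_info = info
    if len(cached_configs) == 1 and len(configs) > 1:  # autotune cache hit
        best = cached_configs[0]
        matches = [r for r in kernel.compile_results if r.config == best]
        if matches:
            kernel.compile_results = matches
        else:
            # The best config came from coordinate descent tuning and was
            # never precompiled: fall back to re-autotuning as on a miss.
            reautotune(kernel)
```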
Test Plan:
New unit test that failed before
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov