Recheck autotune cache on static cuda launcher load #153565
@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


def check_autotune_cache(
    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
    """
    Given a list of configs, checks autotune cache and return metadata
    """
    autotune_cache = None
    autotune_cache_info = {}
    disabled = inductor_meta.get("force_disable_caches", False)
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)

        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]
                autotune_cache_info["best_config"] = triton_config_to_hashable(
                    best_config
                )
                autotune_cache_info["autotune_cache_state"] = "hit"
            else:
                autotune_cache_info["autotune_cache_state"] = "miss"
                autotune_cache_info["num_configs"] = len(configs)
                if inductor_meta.get("coordinate_descent_tuning"):
                    autotune_cache_info["coordesc_tuning"] = True
                    if len(configs) == 1:
                        # This is the config that coordinate descent tuning started at, which
                        # is not the same as the final config chosen (i.e. only_config, best_config)
                        autotune_cache_info["coordesc_tuning_start_config"] = (
                            triton_config_to_hashable(configs[0])
                        )
    else:
        if len(configs) == 1:
            autotune_cache_info["autotune_cache_state"] = "only 1 config"
            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])

    if disabled:
        autotune_cache_info["autotune_cache_state"] = "force_disabled"
Review comment: This is just moved code, but in the case of len(configs) == 1 AND disabled, is it intentional to overwrite the autotune_cache_state?
Reply: Yes, force disable should override all the other logging.
        log.debug("autotune caching is disabled by config.force_disable_caches")

    return configs, autotune_cache, autotune_cache_info


class CachingAutotuner(KernelInterface):
    """
    Simplified version of Triton autotuner that has no invalidation
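As a quick orientation for readers looking at the new helper in isolation, here is a rough, hypothetical sketch of how it would be called (the config list, filename, and inductor_meta values are made up for illustration; only the function signature and the autotune_cache_state values come from the diff above):

    # Illustrative call only; candidate_configs and the filename are placeholders.
    inductor_meta = {
        "force_disable_caches": False,
        "coordinate_descent_tuning": True,
    }
    configs, autotune_cache, info = check_autotune_cache(
        candidate_configs,          # list[Config] of autotune candidates
        "generated_kernel.py",      # path of the generated Triton kernel file
        inductor_meta,
    )
    # info["autotune_cache_state"] is one of "hit", "miss", "only 1 config",
    # or "force_disabled"; on a hit, configs has been pruned to the single
    # best config read back from the cache.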
@@ -298,6 +347,41 @@ def is_statically_launchable(self):
            isinstance(x, StaticTritonCompileResult) for x in self.compile_results
        )

    def recheck_autotune_cache(
        self, reload_kernel_from_src: Callable[[], CachingAutotuner]
    ) -> None:
        """
        On cache load on static autotuner, we need to recheck the autotune cache, since
        a best config could have been found from a previous run
        """
        assert self.is_statically_launchable()

        configs = [result.config for result in self.compile_results]

        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
            configs, self.filename, self.inductor_meta
        )
        self.autotune_cache_info = autotune_cache_info
        # I.e. there was an autotune cache hit
        if len(cached_configs) == 1 and len(configs) > 1:
Review comment: Why do we need to check len(configs) > 1? We'd still coordinate_descent_tune with a single config, right?
Reply: If len(configs) == 1, we don't need to prune the list of configs or compile results, so there's no need to loop through the list. If coordesc tuning is on, then we'll start coordesc tuning immediately.
            best_config = cached_configs[0]
            # Grab the best compiled config, if it's in the list of available ones
            best_config_hash = triton_config_to_hashable(best_config)

            for compile_result in self.compile_results:
                if triton_config_to_hashable(compile_result.config) == best_config_hash:
Comment on lines +371 to +372
Review comment: Just so I am following: if we were coordinate descent tuning and happened to have the best config from the start, then these would be equal?
Reply: Correct. The only case where it falls out of this for loop is if coordesc tuning finds a best_config that wasn't one of the precompiled options. In that case, we haven't saved anything in the cache and need to recompile.
                    self.compile_results = [compile_result]
                    return

            # If the best config isn't in our list of compile results,
            # it's likely because it was found by coordesc after the cache
            # already saved
            if best_config.found_by_coordesc:
                with dynamo_timed("CachingAutotuner.slow_precompile_config"):
                    if self.fn.fn is None:
                        self.fn = reload_kernel_from_src().fn
Comment on lines +379 to +382
Review comment: Is there a reason we need to rely on the found_by_coordesc check here rather than asserting?
Reply: I think this can be an assert, yes, because if best_config isn't in the list of compiled configs it should always be because of coordesc.
Reply: That said, I was a little scared that I might have missed a case where it's possible for best_config to not be in our list, and it doesn't seem particularly helpful to crash in that case. We should just re-autotune and pretend it was a cache miss, IMO.
                    self.compile_results = [self._precompile_config(best_config)]

    def set_compile_info(
        self, compile_id: Optional[CompileId], is_backward: bool
    ) -> None:
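For context, a hypothetical caller-side sketch (the loader function, entry object, and reload_kernel_from_source helper below are illustrative assumptions, not code from this PR): after a statically launchable CachingAutotuner is deserialized from the cache, the loader passes a thunk that can re-import the kernel source, so the slow reload-and-recompile path is only paid when the cached best config was found by coordinate descent and was never precompiled.

    # Hypothetical sketch; only recheck_autotune_cache / is_statically_launchable
    # come from the diff above.
    def load_cached_autotuner(entry) -> CachingAutotuner:
        kernel = entry.deserialize()  # CachingAutotuner built from StaticTritonCompileResults
        if kernel.is_statically_launchable():
            # The lambda is only invoked if the best config is missing from
            # kernel.compile_results and was found by coordesc tuning.
            kernel.recheck_autotune_cache(
                reload_kernel_from_src=lambda: reload_kernel_from_source(entry)
            )
        return kernel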
@@ -1713,47 +1797,9 @@ def cached_autotune(
    assert len(configs) == 1 or filename
    inductor_meta = {} if inductor_meta is None else inductor_meta

    disabled = inductor_meta.get("force_disable_caches", False)

    # on disk caching logic and/or remote caching
    autotune_cache = None
    autotune_cache_info = {}
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)

        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]
                autotune_cache_info["best_config"] = triton_config_to_hashable(
                    best_config
                )
                autotune_cache_info["autotune_cache_state"] = "hit"
            else:
                autotune_cache_info["autotune_cache_state"] = "miss"
                autotune_cache_info["num_configs"] = len(configs)
                if inductor_meta.get("coordinate_descent_tuning"):
                    autotune_cache_info["coordesc_tuning"] = True
                    if len(configs) == 1:
                        # This is the config that coordinate descent tuning started at, which
                        # is not the same as the final config chosen (i.e. only_config, best_config)
                        autotune_cache_info["coordesc_tuning_start_config"] = (
                            triton_config_to_hashable(configs[0])
                        )
    else:
        if len(configs) == 1:
            autotune_cache_info["autotune_cache_state"] = "only 1 config"
            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])

    if disabled:
        autotune_cache_info["autotune_cache_state"] = "force_disabled"
        log.debug("autotune caching is disabled by config.force_disable_caches")

    configs, autotune_cache, autotune_cache_info = check_autotune_cache(
        configs, filename, inductor_meta
    )
    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
    optimize_mem = inductor_meta.pop("optimize_mem", True)
Review comment: So if we only set the remote autotune cache to disabled, in this case we won't honor that?
Reply: The individual cache settings for AutotuneCache are checked separately, on the call to AutotuneCache.create (pytorch/torch/_inductor/runtime/autotune_cache.py, line 63 in a898696).
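To make that split concrete, here is a small hedged sketch (assuming the helper behaves exactly as in the diff above; the config list and filename are placeholders): force_disable_caches short-circuits the lookup entirely and is reported in the metadata, while the finer-grained local/remote cache switches are resolved inside AutotuneCache.create, which is assumed to return None when no cache backend is enabled.

    # Illustrative only; mirrors the force_disable_caches branch of check_autotune_cache.
    meta = {"force_disable_caches": True}
    configs_out, cache, info = check_autotune_cache(candidate_configs, "kernel.py", meta)
    assert cache is None                                   # no AutotuneCache is created
    assert info["autotune_cache_state"] == "force_disabled"

    # With force_disable_caches unset, whether a cache handle comes back depends on
    # the per-backend settings read by AutotuneCache.create() (local and remote);
    # if neither backend is enabled, it is assumed to return None, and no
    # "hit"/"miss" state is recorded.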