Recheck autotune cache on static cuda launcher load by jamesjwu · Pull Request #153565 · pytorch/pytorch · GitHub

Recheck autotune cache on static cuda launcher load #153565


Closed
jamesjwu wants to merge 9 commits
49 changes: 49 additions & 0 deletions test/inductor/test_codecache.py
@@ -2183,6 +2183,55 @@ def reset(self):
        torch._dynamo.reset()
        clear_inductor_caches()

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(not SM80OrLater, "Requires SM80+")
    @unittest.skipIf(
        TEST_WITH_ROCM, "Requires static cuda launcher, which does not support ROCM"
    )
    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @config.patch({"autotune_local_cache": False})
    @config.patch({"autotune_remote_cache": True})
    @config.patch({"bundled_autotune_remote_cache": False})
    @config.patch({"max_autotune": True})
    @config.patch(
        {"compile_threads": 1}
    )  # Worker processes do not register PatchCaches() properly
    def test_autotune_cache_warm_start(self):
        class Model(torch.nn.Module):
            def forward(self, x, y, a, b):
                return x + y, a + b

        def f(x, y, a, b):
            return Model()(x, y, a, b)

        x = torch.randn(100, 100).cuda()
        y = torch.randn(100, 100).cuda()
        a = torch.randn(1000, 100).cuda()
        b = torch.randn(1000, 100).cuda()
        f_compiled = torch.compile(f, fullgraph=True)

        with PatchCaches():
            f_compiled(x, y, a, b)

            self.assertEqual(global_stats.autotune_remote, Stats(2, 0, 2))
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            # Don't reset FxGraphCache, see that it loads again
            torch._dynamo.reset()
            f_compiled(x, y, a, b)
            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

            self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2))

        # Check that the cache entries seem reasonable
        for k in global_stats.autotune_remote.cache.keys():
            self.assertRegex(k, r"[0-9a-z]{52}")
        for k in global_stats.triton.cache.keys():
            self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+")

    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
    @unittest.skipIf(not SM80OrLater, "Requires SM80+")
    @config.patch({"fx_graph_cache": False})
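
The test above exercises a cold compile followed by a warm start: Dynamo state is reset, the FX graph cache stays populated, and the second call is expected to reload the compiled artifact and hit the remote autotune cache rather than re-autotune. A minimal standalone sketch of that pattern (illustrative only; the config patches from the test, such as max_autotune and the remote autotune cache, are omitted, and no cache counters are checked here):

import torch

def f(x, y):
    return x + y  # stand-in for a kernel Inductor would autotune

x = torch.randn(100, 100, device="cuda")
y = torch.randn(100, 100, device="cuda")

f_compiled = torch.compile(f, fullgraph=True)
f_compiled(x, y)       # cold start: compile, autotune, populate caches

torch._dynamo.reset()  # clear Dynamo state but keep the FX graph cache
f_compiled(x, y)       # warm start: FX graph cache hit; with this PR, a statically
                       # launched kernel rechecks the autotune cache on load
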
3 changes: 3 additions & 0 deletions torch/_inductor/codecache.py
@@ -3597,6 +3597,9 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
    def result(self) -> CachingAutotuner:
        assert self.reload_kernel_from_src is not None
        with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
            self.static_autotuner.recheck_autotune_cache(
                reload_kernel_from_src=self.reload_kernel_from_src
            )
            self.static_autotuner.precompile(  # type: ignore[union-attr]
                warm_cache_only=False,
                reload_kernel=self.reload_kernel_from_src,
128 changes: 87 additions & 41 deletions torch/_inductor/runtime/triton_heuristics.py
@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
        f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


def check_autotune_cache(
    configs: list[Config], filename: Optional[str], inductor_meta: dict[str, Any]
) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
    """
    Given a list of configs, checks autotune cache and return metadata
    """
    autotune_cache = None
    autotune_cache_info = {}
    disabled = inductor_meta.get("force_disable_caches", False)
Contributor:
So if we only set the remote autotune cache to disabled, in this case we won't honor that?

Contributor Author:
The individual cache settings for AutotuneCache are checked separately, here on the call to AutotuneCache.create:

    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)

        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]
                autotune_cache_info["best_config"] = triton_config_to_hashable(
                    best_config
                )
                autotune_cache_info["autotune_cache_state"] = "hit"

            else:
                autotune_cache_info["autotune_cache_state"] = "miss"
                autotune_cache_info["num_configs"] = len(configs)
                if inductor_meta.get("coordinate_descent_tuning"):
                    autotune_cache_info["coordesc_tuning"] = True
                    if len(configs) == 1:
                        # This is the config that coordinate descent tuning started at, which
                        # is not the same as the final config chosen (i.e. only_config, best_config)
                        autotune_cache_info["coordesc_tuning_start_config"] = (
                            triton_config_to_hashable(configs[0])
                        )
    else:
        if len(configs) == 1:
            autotune_cache_info["autotune_cache_state"] = "only 1 config"
            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])

    if disabled:
        autotune_cache_info["autotune_cache_state"] = "force_disabled"
Contributor:
This is just moved code, but in the case of len(configs) == 1 AND disabled, is it intentional to overwrite the autotune_cache_state?

Contributor Author:
Yes, force disable should override all the other logging.
        log.debug("autotune caching is disabled by config.force_disable_caches")

    return configs, autotune_cache, autotune_cache_info
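
A minimal usage sketch of the helper above, assuming the caller already has configs, filename, and inductor_meta in scope the way cached_autotune (further down in this file) does; the branch on the returned info dict is illustrative only:

configs, autotune_cache, info = check_autotune_cache(configs, filename, inductor_meta)

if info.get("autotune_cache_state") == "hit":
    # The cache already knows the best config, so configs was narrowed to it.
    assert len(configs) == 1
elif info.get("autotune_cache_state") == "force_disabled":
    # force_disable_caches is set; autotune_cache is None and nothing is cached.
    pass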


class CachingAutotuner(KernelInterface):
"""
Simplified version of Triton autotuner that has no invalidation
Expand Down Expand Up @@ -298,6 +347,41 @@ def is_statically_launchable(self):
isinstance(x, StaticTritonCompileResult) for x in self.compile_results
)

    def recheck_autotune_cache(
        self, reload_kernel_from_src: Callable[[], CachingAutotuner]
    ) -> None:
        """
        On cache load on static autotuner, we need to recheck the autotune cache, since
        a best config could have been found from a previous run
        """
        assert self.is_statically_launchable()

        configs = [result.config for result in self.compile_results]

        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
            configs, self.filename, self.inductor_meta
        )
        self.autotune_cache_info = autotune_cache_info
        # I.e. there was an autotune cache hit
        if len(cached_configs) == 1 and len(configs) > 1:
Contributor:
Why do we need to check len(configs) > 1? We'd still coordinate_descent_tune with a single config, right?

Contributor Author:
If len(configs) == 1, we don't need to prune the list of configs or compile results, so there's no need to loop through the list. If coordesc tuning is on, then we'll start coordesc tuning immediately.

            best_config = cached_configs[0]
            # Grab the best compiled config, if it's in the list of available ones
            best_config_hash = triton_config_to_hashable(best_config)

            for compile_result in self.compile_results:
                if triton_config_to_hashable(compile_result.config) == best_config_hash:
Comment on lines +371 to +372

Contributor:
Just so I am following: if we were coordinate descent tuning, and happened to have the best config from the start, then these would be equal?

Contributor Author:
Correct: the only case where it falls out of this for loop is if coordesc tuning finds a best_config that wasn't one of the precompiled options. In that case, we haven't saved anything in the cache and need to recompile.

                    self.compile_results = [compile_result]
                    return

            # If the best config isn't in our list of compile results,
            # it's likely because it was found by coordesc after the cache
            # already saved
            if best_config.found_by_coordesc:
                with dynamo_timed("CachingAutotuner.slow_precompile_config"):
                    if self.fn.fn is None:
                        self.fn = reload_kernel_from_src().fn
Comment on lines +379 to +382

Contributor:
Is there a reason we need to rely on best_config.found_by_coordesc? Should this be an assert, or should we always be reloading?

Contributor Author:
I think this can be an assert, yes, because if best_config isn't in the list of compiled configs it should always be because of coordesc.

Contributor Author:
That said, I was a little scared that I might have missed a case where it's possible for best_config to be not in our list... and it doesn't seem particularly helpful to crash in that case? We should just re-autotune and pretend it was a cache miss, IMO.

                    self.compile_results = [self._precompile_config(best_config)]
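
For context on where this method is driven from: StaticAutotunerFuture.result() in codecache.py (earlier in this PR) calls it right before precompiling. A minimal sketch of that call site, with autotuner and reload_from_src standing in for the objects the future actually holds:

if autotuner.is_statically_launchable():
    # Re-consult the autotune cache so a best config saved by a previous run
    # prunes autotuner.compile_results before any precompilation happens.
    autotuner.recheck_autotune_cache(reload_kernel_from_src=reload_from_src)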

    def set_compile_info(
        self, compile_id: Optional[CompileId], is_backward: bool
    ) -> None:
@@ -1713,47 +1797,9 @@ def cached_autotune(
    assert len(configs) == 1 or filename
    inductor_meta = {} if inductor_meta is None else inductor_meta

    disabled = inductor_meta.get("force_disable_caches", False)

    # on disk caching logic and/or remote caching
    autotune_cache = None
    autotune_cache_info = {}
    if (
        not disabled
        and filename is not None
        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
    ):
        configs_hash = hash_configs(configs)

        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
        if autotune_cache:
            if best_config := autotune_cache.read_best(inductor_meta, configs):
                configs = [best_config]
                autotune_cache_info["best_config"] = triton_config_to_hashable(
                    best_config
                )
                autotune_cache_info["autotune_cache_state"] = "hit"
            else:
                autotune_cache_info["autotune_cache_state"] = "miss"
                autotune_cache_info["num_configs"] = len(configs)
                if inductor_meta.get("coordinate_descent_tuning"):
                    autotune_cache_info["coordesc_tuning"] = True
                    if len(configs) == 1:
                        # This is the config that coordinate descent tuning started at, which
                        # is not the same as the final config chosen (i.e. only_config, best_config)
                        autotune_cache_info["coordesc_tuning_start_config"] = (
                            triton_config_to_hashable(configs[0])
                        )
    else:
        if len(configs) == 1:
            autotune_cache_info["autotune_cache_state"] = "only 1 config"
            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])

    if disabled:
        autotune_cache_info["autotune_cache_state"] = "force_disabled"
        log.debug("autotune caching is disabled by config.force_disable_caches")

    configs, autotune_cache, autotune_cache_info = check_autotune_cache(
        configs, filename, inductor_meta
    )
    mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
    optimize_mem = inductor_meta.pop("optimize_mem", True)
