Recheck autotune cache on static cuda launcher load · pytorch/pytorch@20aa2c8 · GitHub

Commit 20aa2c8

Recheck autotune cache on static cuda launcher load
ghstack-source-id: da1275b
Pull Request resolved: #153565

1 parent dda2c7c · commit 20aa2c8
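
In short, this change makes a warm start re-consult the autotune cache: when a statically launchable CachingAutotuner is loaded from cache, StaticAutotunerFuture.result() now calls recheck_autotune_cache() before precompile(), so a best config saved by an earlier run (including one found later by coordinate-descent tuning) narrows the configs that get compiled. A minimal sketch of that flow, with the surrounding plumbing simplified and hypothetical:

    # Hedged sketch of the warm-start path touched by this commit; the two calls
    # mirror the codecache.py diff below, everything else is illustrative.
    def warm_start_result(static_autotuner, reload_kernel_from_src):
        # Re-read the autotune cache: a previous run may already have picked a winner.
        static_autotuner.recheck_autotune_cache(
            reload_kernel_from_src=reload_kernel_from_src
        )
        # Precompile only the config(s) that survived the recheck.
        static_autotuner.precompile(
            warm_cache_only=False,
            reload_kernel=reload_kernel_from_src,
        )
        return static_autotuner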

File tree

5 files changed: +136 -43 lines changed

test/inductor/test_codecache.py (+46)

@@ -2183,6 +2183,52 @@ def reset(self):
         torch._dynamo.reset()
         clear_inductor_caches()

+    @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
+    @unittest.skipIf(not SM80OrLater, "Requires SM80+")
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @config.patch({"autotune_local_cache": False})
+    @config.patch({"autotune_remote_cache": True})
+    @config.patch({"bundled_autotune_remote_cache": False})
+    @config.patch({"max_autotune": True})
+    @config.patch(
+        {"compile_threads": 1}
+    )  # Worker processes do not register PatchCaches() properly
+    def test_autotune_cache_warm_start(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y, a, b):
+                return x + y, a + b
+
+        def f(x, y, a, b):
+            return Model()(x, y, a, b)
+
+        x = torch.randn(100, 100).cuda()
+        y = torch.randn(100, 100).cuda()
+        a = torch.randn(1000, 100).cuda()
+        b = torch.randn(1000, 100).cuda()
+        f_compiled = torch.compile(f, fullgraph=True)
+
+        with PatchCaches():
+            f_compiled(x, y, a, b)
+
+            self.assertEqual(global_stats.autotune_remote, Stats(2, 0, 2))
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            # Don't reset FxGraphCache, see that it loads again
+            torch._dynamo.reset()
+            f_compiled(x, y, a, b)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_miss"], 1)
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+            self.assertEqual(global_stats.autotune_remote, Stats(2, 2, 2))
+
+        # Check that the cache entries seem reasonable
+        for k in global_stats.autotune_remote.cache.keys():
+            self.assertRegex(k, r"[0-9a-z]{52}")
+        for k in global_stats.triton.cache.keys():
+            self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c[0-9]+")
+
     @unittest.skipIf(not HAS_CUDA, "Requires CUDA")
     @unittest.skipIf(not SM80OrLater, "Requires SM80+")
     @config.patch({"fx_graph_cache": False})
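
The new test runs a cold compile followed by a warm start with only the remote autotune cache enabled, and checks that the second call hits FxGraphCache and re-reads the remote autotune entries instead of re-autotuning. Outside the test harness, the same warm-start pattern looks roughly like this; a hedged sketch that assumes a CUDA SM80+ machine with the inductor caches enabled:

    # Illustrative only: mirrors the shape of test_autotune_cache_warm_start
    # without the PatchCaches()/counters bookkeeping.
    import torch

    def f(x, y, a, b):
        return x + y, a + b

    f_compiled = torch.compile(f, fullgraph=True)
    x, y = torch.randn(100, 100).cuda(), torch.randn(100, 100).cuda()
    a, b = torch.randn(1000, 100).cuda(), torch.randn(1000, 100).cuda()

    f_compiled(x, y, a, b)   # cold start: autotunes and populates the caches
    torch._dynamo.reset()    # drop dynamo state but keep the FX graph cache
    f_compiled(x, y, a, b)   # warm start: FX graph cache hit, autotune cache rechecked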

third_party/onnx

Submodule onnx updated 1506 files

torch/_inductor/codecache.py (+3)

@@ -3602,6 +3602,9 @@ def __init__(self, static_autotuner: CachingAutotuner) -> None:
     def result(self) -> CachingAutotuner:
         assert self.reload_kernel_from_src is not None
         with dynamo_timed("StaticAutotunerFuture.warm_precompile"):
+            self.static_autotuner.recheck_autotune_cache(
+                reload_kernel_from_src=self.reload_kernel_from_src
+            )
             self.static_autotuner.precompile(  # type: ignore[union-attr]
                 warm_cache_only=False,
                 reload_kernel=self.reload_kernel_from_src,

torch/_inductor/runtime/triton_heuristics.py (+85 -41)

@@ -187,6 +187,55 @@ def _dump_launch_params(args, kwargs, launcher, kernel_name, grid):
         f.write(f"{kernel_name} | {args_str} | {grid!r}\n")


+def check_autotune_cache(
+    configs, filename, inductor_meta
+) -> tuple[list[Config], Optional[AutotuneCache], dict[str, Any]]:
+    """
+    Given a list of configs, checks autotune cache and return metadata
+    """
+    autotune_cache = None
+    autotune_cache_info = {}
+    disabled = inductor_meta.get("force_disable_caches", False)
+    if (
+        not disabled
+        and filename is not None
+        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
+        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
+    ):
+        configs_hash = hash_configs(configs)
+
+        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
+        if autotune_cache:
+            if best_config := autotune_cache.read_best(inductor_meta, configs):
+                configs = [best_config]
+                autotune_cache_info["best_config"] = triton_config_to_hashable(
+                    best_config
+                )
+                autotune_cache_info["autotune_cache_state"] = "hit"
+
+            else:
+                autotune_cache_info["autotune_cache_state"] = "miss"
+                autotune_cache_info["num_configs"] = len(configs)
+                if inductor_meta.get("coordinate_descent_tuning"):
+                    autotune_cache_info["coordesc_tuning"] = True
+                    if len(configs) == 1:
+                        # This is the config that coordinate descent tuning started at, which
+                        # is not the same as the final config chosen (i.e. only_config, best_config)
+                        autotune_cache_info["coordesc_tuning_start_config"] = (
+                            triton_config_to_hashable(configs[0])
+                        )
+    else:
+        if len(configs) == 1:
+            autotune_cache_info["autotune_cache_state"] = "only 1 config"
+            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
+
+    if disabled:
+        autotune_cache_info["autotune_cache_state"] = "force_disabled"
+        log.debug("autotune caching is disabled by config.force_disable_caches")
+
+    return configs, autotune_cache, autotune_cache_info
+
+
 class CachingAutotuner(KernelInterface):
     """
     Simplified version of Triton autotuner that has no invalidation
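
For orientation, here is a hedged sketch of how a caller consumes the helper's return value; the names mirror the diff (check_autotune_cache, autotune_cache_state), but the wrapper itself is illustrative, not a real call site:

    # Assumes a torch build that includes this commit, where the helper exists.
    from torch._inductor.runtime.triton_heuristics import check_autotune_cache

    def narrow_configs(configs, filename, inductor_meta):
        configs, autotune_cache, info = check_autotune_cache(
            configs, filename, inductor_meta
        )
        if info.get("autotune_cache_state") == "hit":
            # A previous run already picked a winner; only that config remains.
            assert len(configs) == 1
        # On a miss, all candidates remain and `autotune_cache` (when not None)
        # is the handle later used to persist whichever config wins benchmarking.
        return configs, autotune_cache, info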
@@ -298,6 +347,39 @@ def is_statically_launchable(self):
             isinstance(x, StaticTritonCompileResult) for x in self.compile_results
         )

+    def recheck_autotune_cache(self, reload_kernel_from_src) -> None:
+        """
+        On cache load on static autotuner, we need to recheck the autotune cache, since
+        a best config could have been found from a previous run
+        """
+        assert self.is_statically_launchable()
+
+        configs = [result.config for result in self.compile_results]
+
+        (cached_configs, _, autotune_cache_info) = check_autotune_cache(
+            configs, self.filename, self.inductor_meta
+        )
+        self.autotune_cache_info = autotune_cache_info
+        # I.e. there was an autotune cache hit
+        if len(cached_configs) == 1 and len(configs) > 1:
+            best_config = cached_configs[0]
+            # Grab the best compiled config, if it's in the list of available ones
+            best_config_hash = triton_config_to_hashable(best_config)
+
+            for compile_result in self.compile_results:
+                if triton_config_to_hashable(compile_result.config) == best_config_hash:
+                    self.compile_results = [compile_result]
+                    return
+
+            # If the best config isn't in our list of compile results,
+            # it's likely because it was found by coordesc after the cache
+            # already saved
+            if best_config.found_by_coordesc:
+                with dynamo_timed("CachingAutotuner.slow_precompile_config"):
+                    if self.fn.fn is None:
+                        self.fn = reload_kernel_from_src().fn
+                    self.compile_results = [self._precompile_config(best_config)]
+
     def set_compile_info(
         self, compile_id: Optional[CompileId], is_backward: bool
     ) -> None:
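
The selection step above boils down to "keep the precompiled result whose config hashes to the cached best config, else recompile that single config if coordinate descent found it after the static results were saved". A toy, self-contained restatement of just the matching part, with stand-in objects rather than the real Config and compile-result classes:

    from types import SimpleNamespace

    def pick_best(compile_results, best_key, to_key):
        """Keep only the result whose config matches the cached best config."""
        for result in compile_results:
            if to_key(result.config) == best_key:
                return [result]
        return list(compile_results)  # no match: caller may recompile the best config

    # Tiny usage example with stand-in objects:
    results = [SimpleNamespace(config="cfg_a"), SimpleNamespace(config="cfg_b")]
    assert pick_best(results, "cfg_b", to_key=lambda c: c) == [results[1]]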
@@ -1713,47 +1795,9 @@ def cached_autotune(
     assert len(configs) == 1 or filename
     inductor_meta = {} if inductor_meta is None else inductor_meta

-    disabled = inductor_meta.get("force_disable_caches", False)
-
-    # on disk caching logic and/or remote caching
-    autotune_cache = None
-    autotune_cache_info = {}
-    if (
-        not disabled
-        and filename is not None
-        and (len(configs) > 1 or inductor_meta.get("coordinate_descent_tuning"))
-        and not os.environ.get("TRITON_INTERPRET", "0") == "1"
-    ):
-        configs_hash = hash_configs(configs)
-
-        autotune_cache = AutotuneCache.create(inductor_meta, filename, configs_hash)
-        if autotune_cache:
-            if best_config := autotune_cache.read_best(inductor_meta, configs):
-                configs = [best_config]
-                autotune_cache_info["best_config"] = triton_config_to_hashable(
-                    best_config
-                )
-                autotune_cache_info["autotune_cache_state"] = "hit"
-            else:
-                autotune_cache_info["autotune_cache_state"] = "miss"
-                autotune_cache_info["num_configs"] = len(configs)
-                if inductor_meta.get("coordinate_descent_tuning"):
-                    autotune_cache_info["coordesc_tuning"] = True
-                    if len(configs) == 1:
-                        # This is the config that coordinate descent tuning started at, which
-                        # is not the same as the final config chosen (i.e. only_config, best_config)
-                        autotune_cache_info["coordesc_tuning_start_config"] = (
-                            triton_config_to_hashable(configs[0])
-                        )
-    else:
-        if len(configs) == 1:
-            autotune_cache_info["autotune_cache_state"] = "only 1 config"
-            autotune_cache_info["only_config"] = triton_config_to_hashable(configs[0])
-
-    if disabled:
-        autotune_cache_info["autotune_cache_state"] = "force_disabled"
-        log.debug("autotune caching is disabled by config.force_disable_caches")
-
+    configs, autotune_cache, autotune_cache_info = check_autotune_cache(
+        configs, filename, inductor_meta
+    )
     mutated_arg_names = inductor_meta.pop("mutated_arg_names", ())
     optimize_mem = inductor_meta.pop("optimize_mem", True)
