[Inductor] Record Triton's Base32 Cache Key in .best_config for Debugging · pytorch/pytorch@fc6e37c · GitHub

Commit fc6e37c

fulvius31 authored and pytorchmergebot committed
[Inductor] Record Triton’s Base32 Cache Key in .best_config for Debugging (#148981)
This is a follow-up to the reverted PR #147019: it modifies TorchInductor's autotuning flow so that each best_config JSON file also records the Triton "base32" (or base64) cache key.

Motivation (debugging & analysis): with this change we can quickly identify which compiled binary and IRs belong to a given best config. The impact is minimal, since it only adds an extra field to .best_config, yet it helps with advanced performance tuning and kernel-level debugging. Also, because Triton already stores the cubin/hsaco in its cache, developers and researchers can avoid setting store_cubin = True: they can fetch the cubin/hsaco from the Triton cache and, with the code in this PR, easily match each best_config to the Triton cache directory of the "best" kernel.

Pull Request resolved: #148981
Approved by: https://github.com/davidberard98
1 parent 0413358 commit fc6e37c
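
As an illustration of the debugging workflow this enables, the sketch below scans an Inductor cache directory for .best_config files and maps each one to its Triton cache directory through the new triton_cache_hash field. This is a minimal sketch, not part of the PR: the two cache locations are assumptions (override them via TORCHINDUCTOR_CACHE_DIR and TRITON_CACHE_DIR, as the test added here does), and the fallback paths are only typical defaults.

import glob
import json
import os

# Assumed cache locations; adjust to your setup. TORCHINDUCTOR_CACHE_DIR and
# TRITON_CACHE_DIR take precedence when set, as in the test added by this PR.
inductor_cache = os.environ.get("TORCHINDUCTOR_CACHE_DIR", "/tmp/torchinductor_cache")
triton_cache = os.environ.get("TRITON_CACHE_DIR", os.path.expanduser("~/.triton/cache"))

pattern = os.path.join(inductor_cache, "**", "*.best_config")
for best_config_path in glob.glob(pattern, recursive=True):
    with open(best_config_path) as f:
        best = json.load(f)
    cache_hash = best.get("triton_cache_hash")
    if cache_hash is None:
        continue  # file written before this change, or no hash was recorded
    # The hash doubles as the subdirectory name inside the Triton cache, so the
    # compiled artifacts (cubin/hsaco, IRs) of the winning config live here:
    kernel_dir = os.path.join(triton_cache, cache_hash)
    print(f"{best_config_path} -> {kernel_dir} (exists: {os.path.isdir(kernel_dir)})")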

File tree

test/inductor/test_best_config.py
test/run_test.py
torch/_inductor/runtime/autotune_cache.py
torch/_inductor/runtime/triton_heuristics.py

4 files changed: +111 −2 lines changed

test/inductor/test_best_config.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# Owner(s): ["module: inductor"]

import glob
import json
import os
import sys
import tempfile
import unittest

import torch
from torch._inductor import config
from torch.testing._internal.common_utils import IS_LINUX, skipIfXpu
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU


try:
    import triton  # noqa: F401
except ImportError as e:
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires triton") from e

from torch._inductor.test_case import run_tests, TestCase


def trivial_kernel(x):
    return torch.sin(x) + torch.cos(x)


class TestKernelBestConfig(TestCase):
    device_type = GPU_TYPE

    @classmethod
    def setUpClass(cls):
        # Save the original configuration and environment variables.
        cls.original_compile_threads = config.compile_threads
        cls.original_max_autotune = config.max_autotune
        cls.original_inductor_env = os.environ.get("TORCHINDUCTOR_CACHE_DIR", "")
        cls.original_triton_env = os.environ.get("TRITON_CACHE_DIR", "")
        super().setUpClass()

    @classmethod
    def tearDownClass(cls):
        # Restore the original configuration and environment variables.
        os.environ["TORCHINDUCTOR_CACHE_DIR"] = cls.original_inductor_env
        os.environ["TRITON_CACHE_DIR"] = cls.original_triton_env
        config.compile_threads = cls.original_compile_threads
        config.max_autotune = cls.original_max_autotune
        super().tearDownClass()

    @skipIfXpu
    def test_best_config_has_triton_cache_key(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = tmpdir
            triton_cache_dir = os.path.join(tmpdir, "triton_cache")
            os.environ["TRITON_CACHE_DIR"] = triton_cache_dir

            config.compile_threads = 0
            config.max_autotune = True

            compiled_fn = torch.compile(trivial_kernel)

            x = torch.randn(32, 10, device=GPU_TYPE)
            compiled_fn(x)

            # Search for .best_config files in the inductor cache directory.
            best_config_files = glob.glob(
                os.path.join(tmpdir, "**", "*.best_config"), recursive=True
            )
            self.assertGreater(
                len(best_config_files),
                0,
                f"No best_config files found in {tmpdir}. Directory contents: {os.listdir(tmpdir)}",
            )

            # Validate that each best_config file contains a real triton_cache_hash,
            # and that a corresponding Triton cache directory exists.
            for file_path in best_config_files:
                with open(file_path) as f:
                    data = json.load(f)
                self.assertIn(
                    "triton_cache_hash",
                    data,
                    f"Missing triton_cache_hash in {os.path.basename(file_path)}",
                )
                cache_hash = data["triton_cache_hash"]
                expected_path = os.path.join(triton_cache_dir, cache_hash)
                self.assertTrue(
                    os.path.exists(expected_path),
                    f"Triton cache directory missing: {expected_path}",
                )


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()
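
A note on the test's setup: it pins compile_threads = 0 (no parallel compile workers, presumably so the .best_config files are written before the assertions run) and max_autotune = True (so autotuning definitely runs, even for this trivial kernel). The test is gated to Linux machines with a GPU; locally it can be invoked directly with python test/inductor/test_best_config.py.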

test/run_test.py

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@ def __contains__(self, item):
     "test_unary_ufuncs",
     # these tests fail when cuda is not available
     "inductor/test_aot_inductor",
+    "inductor/test_best_config",
     "inductor/test_cudacodecache",
     "inductor/test_inductor_utils",
     "inductor/test_inplacing_pass",

torch/_inductor/runtime/autotune_cache.py

Lines changed: 8 additions & 1 deletion
@@ -214,7 +214,11 @@ def __setstate__(self, state: dict[str, Any]) -> None:

     # Save the config in the caches
     def save(
-        self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
+        self,
+        config: Config,
+        time_taken_ns: int,
+        found_by_coordesc: bool = False,
+        triton_cache_hash: Optional[str] = None,
     ) -> None:
         data = {
             **config.kwargs,
@@ -223,6 +227,7 @@ def save(
             "configs_hash": self.configs_hash,
             "found_by_coordesc": found_by_coordesc,
             "time_taken_ms": time_taken_ns // 1000000,  # Convert from NS to MS
+            "triton_cache_hash": triton_cache_hash,
         }
         if HAS_WARP_SPEC:
             data.update(
@@ -485,6 +490,8 @@ def _load_cached_autotuning(
     # Remove time taken for comparison
     best_config.pop("time_taken_ms", None)

+    best_config.pop("triton_cache_hash", None)
+
     if inductor_meta.get("coordinate_descent_tuning") and best_config.pop(
         "found_by_coordesc", False
     ):
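
For reference, here is what a .best_config payload looks like after this change. The values are illustrative only: the tuning knobs come from config.kwargs and vary per kernel, and both hashes below are made up. Note that _load_cached_autotuning pops triton_cache_hash (like time_taken_ms) before comparing against candidate configs, so the new field can never cause a cache miss.

# Illustrative .best_config contents; every value below is invented.
best_config = {
    "XBLOCK": 128,                       # example tuning knob from config.kwargs
    "num_warps": 4,
    "num_stages": 1,
    "configs_hash": "e3b0c44298fc1c14",  # hypothetical
    "found_by_coordesc": False,
    "time_taken_ms": 7,
    "triton_cache_hash": "c5h66ssvyizlk",  # hypothetical base32 key; also the name
                                           # of the matching Triton cache subdirectory
}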

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 6 additions & 1 deletion
@@ -871,7 +871,11 @@ def autotune_to_one_config(self, *args, **kwargs):
         )

         if self.save_cache_hook:
-            self.save_cache_hook(launcher.config, self.autotune_time_taken_ns)
+            self.save_cache_hook(
+                launcher.config,
+                self.autotune_time_taken_ns,
+                triton_cache_hash=launcher.cache_hash,
+            )

     def save_gpu_kernel(self, stream, launcher):
         key = self.inductor_meta.get("kernel_name", None)  # unique kernel name
@@ -1499,6 +1503,7 @@ def make_launcher(self) -> LauncherType:
         launcher.n_regs = getattr(binary, "n_regs", None)
         launcher.n_spills = getattr(binary, "n_spills", None)
         launcher.shared = binary_shared
+        launcher.cache_hash = triton_hash_to_path_key(binary.hash)
         launcher.store_cubin = self.inductor_meta.get("store_cubin", False)
         # store this global variable to avoid the high overhead of reading it when calling run
         if launcher.store_cubin:
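
The new launcher.cache_hash comes from triton_hash_to_path_key(binary.hash), which turns the compiled binary's hash into the directory name Triton uses under its cache root. As a rough sketch of that mapping (an assumption about recent Triton versions, which encode the hex digest in unpadded lowercase base32; older releases used the raw hash or base64, and hash_to_cache_dir_name below is a made-up name, not the real helper):

import base64

def hash_to_cache_dir_name(hex_hash: str) -> str:
    # Sketch only: encode the kernel's hex digest as unpadded, lowercase base32,
    # matching how recent Triton builds name their cache subdirectories.
    return base64.b32encode(bytes.fromhex(hex_hash)).decode("utf-8").lower().rstrip("=")

With that key recorded in .best_config, the test above can simply assert that os.path.join(TRITON_CACHE_DIR, cache_hash) exists.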

0 commit comments
