Keep raw cubin file around in case it gets deleted underneath us (#153064) · pytorch/pytorch@4976b1a · GitHub

Commit 4976b1a

jamesjwu authored and pytorchmergebot committed
Keep raw cubin file around in case it gets deleted underneath us (#153064)
This diff hardens StaticCudaLauncher in the event that a cubin file gets deleted out from under us. We store the raw cubin on the static cuda launcher and reload it as needed. On cold start, this can happen if the cubin file is created by triton and gets deleted before we can load the kernel in the parent process.

We don't want to store the entire cubin both in file format and in memory for caching purposes, so we delete it before caching the data. In the unlikely event that we can't find or load the necessary file on warm start, we skip the stored triton launcher and fall back to regular triton.

This comes at a cost to worker memory, but it's not more memory than regular triton workers already take, so it should be okay.

Tests:
- Make test_static_cuda_launcher always delete the cubin path and reload it

Fixes #153030

Pull Request resolved: #153064
Approved by: https://github.com/oulgen, https://github.com/jansel
1 parent 13bdfe6 commit 4976b1a
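For readers skimming the diff, the core pattern is small: keep the raw cubin bytes alongside the file path, and rewrite the file from those bytes if the path has gone missing before the kernel is loaded, then drop both once the kernel is loaded. The sketch below is a minimal, standalone illustration of that idea, not the PyTorch implementation; the RawBackedKernel class and its fields are hypothetical stand-ins for StaticallyLaunchedCudaKernel.

import os
import tempfile
from typing import Optional


class RawBackedKernel:
    """Minimal sketch: hold both a cubin file path and the raw bytes it came from."""

    def __init__(self, cubin_path: str, cubin_raw: bytes) -> None:
        self.cubin_path: Optional[str] = cubin_path
        self.cubin_raw: Optional[bytes] = cubin_raw

    def reload_cubin_from_raw(self, filepath: str) -> str:
        """If the on-disk cubin vanished, recreate it from the raw bytes."""
        if self.cubin_path is None:
            assert self.cubin_raw is not None, "raw cubin was already dropped"
            os.makedirs(os.path.dirname(filepath), exist_ok=True)
            with open(filepath, "wb") as f:
                f.write(self.cubin_raw)
            self.cubin_path = filepath
        return self.cubin_path

    def load_kernel(self) -> bytes:
        """Stand-in for the real loader: read the cubin, then drop both copies."""
        assert self.cubin_path is not None
        with open(self.cubin_path, "rb") as f:
            data = f.read()
        # Once loaded, neither the path nor the raw bytes are needed anymore.
        self.cubin_path = None
        self.cubin_raw = None
        return data


if __name__ == "__main__":
    tmpdir = tempfile.mkdtemp()
    path = os.path.join(tmpdir, "kernel.cubin")
    raw = b"\x7fELF-fake-cubin"
    with open(path, "wb") as f:
        f.write(raw)
    kernel = RawBackedKernel(path, raw)

    os.remove(path)           # the file gets deleted underneath us
    kernel.cubin_path = None  # and we notice the path is no longer valid
    kernel.reload_cubin_from_raw(path)
    print(len(kernel.load_kernel()), "bytes reloaded")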

File tree

4 files changed: +51 −3 lines changed

test/inductor/test_static_cuda_launcher.py

Lines changed: 5 additions & 0 deletions

@@ -54,6 +54,11 @@ def _make_launcher(
         cubin_file = self.write_cubin_to_tmp(compiled_kernel)
         compiled_kernel._cubin_path = cubin_file
         result = StaticallyLaunchedCudaKernel(compiled_kernel)
+        # Test reload cubin from raw here
+        old_cubin_path = result.cubin_path
+        assert old_cubin_path is not None
+        result.cubin_path = None
+        result.reload_cubin_from_raw(old_cubin_path)
         device_interface = get_interface_for_device("cuda")
         result.load_kernel(device_interface.current_device())
         return result

torch/_inductor/runtime/static_cuda_launcher.py

Lines changed: 16 additions & 0 deletions

@@ -1,4 +1,5 @@
 import functools
+import os
 from typing import Any, Optional
 from typing_extensions import Unpack
 
@@ -34,6 +35,7 @@ class StaticallyLaunchedCudaKernel:
 
     def __init__(self, kernel: CompiledKernel) -> None:
         self.name = kernel.src.fn.__name__
+        self.cubin_raw = kernel.asm.get("cubin", None)
         self.cubin_path = kernel._cubin_path
 
         # Used by torch.compile to filter constants in older triton versions
@@ -87,6 +89,19 @@ def __init__(self, kernel: CompiledKernel) -> None:
                "Static cuda launcher only supports num_ctas == 1"
            )
 
+    def reload_cubin_from_raw(self, filepath: str) -> str:
+        """
+        If the cubin file triton generated gets deleted under us, we can
+        reload it from the raw cubin file.
+        """
+        if self.cubin_path is None:
+            assert self.cubin_raw is not None
+            os.makedirs(os.path.dirname(filepath), exist_ok=True)
+            with open(filepath, "wb") as f:
+                f.write(self.cubin_raw)
+            self.cubin_path = filepath
+        return self.cubin_path
+
     def load_kernel(self, device: int) -> None:
         from torch._C import _StaticCudaLauncher
 
@@ -100,6 +115,7 @@ def load_kernel(self, device: int) -> None:
        )
        # Don't need the cubin path anymore now that we've loaded
        self.cubin_path = None
+       self.cubin_raw = None
 
    @staticmethod
    @functools.lru_cache

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 17 additions & 3 deletions

@@ -506,6 +506,16 @@ def prepare_for_pickle(self) -> tuple[Any, Any, Any, Any, Any]:
         self.launchers = []
         return old_values
 
+    def prepare_for_caching(self) -> None:
+        """
+        Statically launched CUDA kernels have a raw cubin on them
+        that we don't need to store in the cache (since TritonBundler handles the collection for us)
+        """
+        for result in self.compile_results:
+            if isinstance(result, StaticTritonCompileResult):
+                # Don't save this in the inductor cache, as it is very large
+                result.kernel.cubin_raw = None
+
     def __getstate__(self) -> dict[str, Any]:
         assert not self.launchers, (
             "pickle should not be called with after make_launchers()"
@@ -1268,9 +1278,13 @@ def reload_cubin_path(self):
             f"{self.kernel.name}.cubin",
         )
         if not os.path.exists(cubin_location):
-            raise RuntimeError(
-                "Cubin file saved by TritonBundler not found at %s", cubin_location
-            )
+            if self.kernel.cubin_raw is not None:
+                # We saved the raw cubin, so write it to the appropriate location
+                self.kernel.reload_cubin_from_raw(cubin_location)
+            else:
+                raise RuntimeError(
+                    "Cubin file saved by TritonBundler not found at %s", cubin_location
+                )
         self.kernel.cubin_path = cubin_location
 
     def make_launcher(self) -> LauncherType:
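The prepare_for_caching hook above exists for the size trade-off the commit message describes: the bundler already persists the cubin file, so the raw bytes don't need to ride along in the pickled cache entry. Below is a hedged, self-contained sketch of that trade-off; KernelLike and its made-up /tmp path are hypothetical, and the real method lives on CachingAutotuner and walks compile_results.

import copy
import pickle


class KernelLike:
    """Hypothetical stand-in for a kernel holder carrying raw cubin bytes."""

    def __init__(self, raw: bytes) -> None:
        self.cubin_raw = raw
        self.cubin_path = "/tmp/kernel.cubin"  # made-up path, for illustration only


def prepare_for_caching(kernel: KernelLike) -> None:
    # The bundler stores the cubin file separately, so drop the in-memory copy
    # before the object is pickled into the cache.
    kernel.cubin_raw = None


if __name__ == "__main__":
    k = KernelLike(b"\x00" * 1_000_000)  # pretend 1 MB cubin
    cached = copy.deepcopy(k)            # mirrors the deepcopy before caching
    prepare_for_caching(cached)
    print(len(pickle.dumps(k)), "bytes with raw cubin")
    print(len(pickle.dumps(cached)), "bytes after prepare_for_caching")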

torch/_inductor/triton_bundler.py

Lines changed: 13 additions & 0 deletions

@@ -173,7 +173,9 @@ def put_static_autotuner(cls, key: str, kernel: "CachingAutotuner") -> None: #
         # for FXGraphCache
         old_values = kernel.prepare_for_pickle()
         new_kernel = copy.deepcopy(kernel)
+        new_kernel.prepare_for_caching()
         new_kernel._reload_kernel = None
+
         entries.append(
             StaticallyLaunchedAutotuner(
                 key,
@@ -223,6 +225,17 @@ def load_autotuners(
         kernel_names = []
         with dynamo_timed("TritonBundler.load_cached_static_autotuners"):
             for result in static_autotuners:
+                try:
+                    # Make sure the cubin path exists and is valid
+                    for compile_result in result.kernel.compile_results:
+                        compile_result.reload_cubin_path()
+                except RuntimeError as e:
+                    log.warning(
+                        "Failed to reload cubin file for statically launchable autotuner %s: %s",
+                        result.kernel_name,
+                        e,
+                    )
+                    continue
                 # We make a future instead of returning the kernel here so that
                 # kernels that are not statically launchable (i.e. cache miss)
                 # can launch a worker without waiting on the blocking step of
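The try/except added above encodes the warm-start fallback from the commit message: if a cached kernel's cubin can't be recovered, the entry is skipped with a warning rather than failing compilation, and regular triton is used instead. A hedged, runnable sketch of that policy follows; FakeEntry and load_cached_kernels are hypothetical stand-ins for the real loop in TritonBundler.load_autotuners.

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger("bundle-load-sketch")


class FakeEntry:
    """Hypothetical stand-in for a cached, statically launchable autotuner."""

    def __init__(self, name: str, has_cubin: bool) -> None:
        self.name = name
        self.has_cubin = has_cubin

    def reload_cubin_path(self) -> None:
        # The real code checks the TritonBundler's on-disk cubin and then the
        # raw bytes; here we just simulate success or failure.
        if not self.has_cubin:
            raise RuntimeError(f"cubin for {self.name} not found")


def load_cached_kernels(entries):
    """Skip (with a warning) any entry whose cubin cannot be recovered."""
    loaded = []
    for entry in entries:
        try:
            entry.reload_cubin_path()
        except RuntimeError as exc:
            log.warning("Failed to reload cubin for %s: %s", entry.name, exc)
            continue  # this kernel falls back to regular triton
        loaded.append(entry)
    return loaded


if __name__ == "__main__":
    entries = [FakeEntry("triton_poi_fused_0", True), FakeEntry("triton_red_fused_1", False)]
    print([e.name for e in load_cached_kernels(entries)])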
