Pass inductor config for static cuda launcher to workers (#153382) · pytorch/pytorch@dda2c7c · GitHub
[go: up one dir, main page]

Skip to content

Commit dda2c7c

Browse files
jamesjwu authored and pytorchmergebot committed
Pass inductor config for static cuda launcher to workers (#153382)
Async compile workers generally don't respect inductor configs that get changed in the middle of execution, because the workers warm up early. StaticCudaLauncher is especially susceptible to this because it affects triton compilation without being part of the inductor meta. So we'll pass it in via extra configs on each worker run.

Pull Request resolved: #153382
Approved by: https://github.com/masnesral, https://github.com/jansel
1 parent 6a28cc8 commit dda2c7c

File tree

3 files changed

+38
-10
lines changed

3 files changed

+38
-10
lines changed

test/inductor/test_static_cuda_launcher.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import random
44
import tempfile
5+
from unittest import mock
56

67
import torch
78
from torch._dynamo.device_interface import get_interface_for_device
@@ -496,6 +497,24 @@ def fn(x):
496497
compiled_result = compiled_fn(arg)
497498
self.assertEqual(eager_result, compiled_result)
498499

500+
@skipIfRocm
501+
def test_disable_static_cuda_launcher(self):
502+
@torch.compile
503+
def fn(x, y):
504+
return torch.cat(((x * 4), y + 10))
505+
506+
# Test that static cuda launcher is in fact disabled
507+
with torch._inductor.config.patch("use_static_cuda_launcher", False):
508+
x = torch.rand(20, device="cuda")
509+
y = torch.rand(20, device="cuda")
510+
with mock.patch(
511+
"torch._inductor.runtime.triton_heuristics.StaticTritonCompileResult.make_launcher"
512+
) as mocked:
513+
result = fn(x, y)
514+
mocked.assert_not_called()
515+
516+
self.assertEqual(result, torch.cat(((x * 4), y + 10)))
517+
499518

500519
if __name__ == "__main__":
501520
from torch._inductor.test_case import run_tests

torch/_inductor/async_compile.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,11 +351,15 @@ def reload_kernel_in_parent():
351351
# process pool is running, so pass them to the subprocess to reset.
352352
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
353353
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
354+
extra_config = {
355+
"use_static_cuda_launcher": torch._inductor.config.use_static_cuda_launcher
356+
}
354357

355358
task = self.process_pool().submit(
356359
_worker_compile_triton,
357360
load_kernel,
358361
extra_env,
362+
extra_config,
359363
)
360364

361365
def get_result() -> CachingAutotuner:

torch/_inductor/runtime/compile_tasks.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import warnings
99
from pathlib import Path
1010
from types import ModuleType
11-
from typing import Callable, TYPE_CHECKING
11+
from typing import Any, Callable, TYPE_CHECKING
1212

1313

1414
if TYPE_CHECKING:
@@ -48,15 +48,20 @@ def _set_triton_ptxas_path() -> None:
4848

4949

5050
def _worker_compile_triton(
51-
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
51+
load_kernel: Callable[[], CachingAutotuner],
52+
extra_env: dict[str, str],
53+
extra_config: dict[str, Any],
5254
) -> tuple[CachingAutotuner, int]:
5355
_set_triton_ptxas_path()
5456
os.environ.update(extra_env)
55-
start_ns = time.time_ns()
56-
kernel = load_kernel()
57-
kernel.precompile(warm_cache_only=True)
58-
elapsed_ns = time.time_ns() - start_ns
59-
kernel.prepare_for_pickle()
60-
# We can release this memory in the compile subprocesses:
61-
linecache.clearcache()
62-
return kernel, elapsed_ns // 1000
57+
from torch._inductor import config
58+
59+
with config.patch(extra_config):
60+
start_ns = time.time_ns()
61+
kernel = load_kernel()
62+
kernel.precompile(warm_cache_only=True)
63+
elapsed_ns = time.time_ns() - start_ns
64+
kernel.prepare_for_pickle()
65+
# We can release this memory in the compile subprocesses:
66+
linecache.clearcache()
67+
return kernel, elapsed_ns // 1000

0 commit comments

Comments (0)