Back out "Precompile triton templates (#121998)" (#123305) · pytorch/pytorch@e0c9764 · GitHub

Commit e0c9764

yoyoyocmu authored and pytorchmergebot committed
Back out "Precompile triton templates (#121998)" (#123305)
Summary: We are reverting #121998 because the change, combined with search-autotune-cache, led to a significant compilation time increase, causing the stuck job detector to trigger and then kill the training job.

Test Plan: CI tests

Reviewed By: nmacchioni

Differential Revision: D55712203

Pull Request resolved: #123305

Approved by: https://github.com/eellison, https://github.com/nmacchioni, https://github.com/xw285cornell
1 parent 595613d commit e0c9764

File tree

4 files changed (+24, -46 lines)

torch/_inductor/autotune_process.py

Lines changed: 6 additions & 8 deletions
@@ -502,6 +502,7 @@ def benchmark(
 class TritonBenchmarkRequest(BenchmarkRequest):
     # Important: Instances of this class have to be serializable
     # across process boundaries. Do not put CUDA Tensors in here!
+
     def __init__(
         self,
         kernel_name: str,
@@ -544,8 +545,6 @@ def make_run_fn(
         if "warmup" in inspect.signature(run_method).parameters:
             warmup_arg["warmup"] = False

-        from torch._C import _cuda_getCurrentRawStream as get_raw_stream
-
         if torch.version.hip and self.matrix_instr_nonkdim != 0:
             return functools.partial(
                 run_method,
@@ -554,7 +553,9 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                stream=get_raw_stream(self.output_tensor_meta.device.index),
+                num_stages=self.num_stages,
+                num_warps=self.num_warps,
+                matrix_instr_nonkdim=self.matrix_instr_nonkdim,
             )
         else:
             return functools.partial(
@@ -564,13 +565,10 @@ def make_run_fn(
                 *self.extra_args,
                 grid=self.grid,
                 **warmup_arg,
-                stream=get_raw_stream(self.output_tensor_meta.device.index),
+                num_stages=self.num_stages,
+                num_warps=self.num_warps,
             )

-    def precompile(self):
-        mod = PyCodeCache.load_by_key_path(self.module_cache_key, self.module_path)
-        getattr(mod, self.kernel_name).precompile()
-
     def __str__(self) -> str:
         return f"{self.kernel_name=}, {self.module_path=}, {self.module_cache_key=}"
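For context on the make_run_fn change above: the reverted version goes back to binding the Triton launch parameters (num_stages, num_warps, and matrix_instr_nonkdim on ROCm) into a functools.partial over the kernel's run method, rather than resolving a raw CUDA stream up front. Below is a minimal, self-contained sketch of that binding pattern; the run function and all values are illustrative stand-ins, not inductor's API.

import functools

def run(*args, grid=None, warmup=False, num_stages=1, num_warps=4):
    # Stand-in for a compiled Triton kernel's run method.
    print(f"{len(args)} args, grid={grid}, warmup={warmup}, "
          f"num_stages={num_stages}, num_warps={num_warps}")

# Hypothetical values standing in for self.grid / self.num_stages / self.num_warps.
run_fn = functools.partial(run, grid=(8, 1, 1), warmup=False, num_stages=3, num_warps=8)
run_fn("input_buf", "weight_buf")  # extra args are forwarded positionally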

torch/_inductor/codegen/triton_utils.py

Lines changed: 3 additions & 27 deletions
@@ -63,32 +63,6 @@ def signature_to_meta(
     }


-def is_unaligned_buffer(arg: TensorArg):
-    buf_name = arg.buffer
-    if buf_name in V.graph.graph_inputs:
-        return not config.assume_aligned_inputs
-
-    if buf_name in V.graph.constants:
-        # all constants are assumed to be aligned
-        return False
-
-    if V.graph.scheduler:
-        layout = V.graph.scheduler.get_buffer_layout(buf_name)
-    else:
-        buffer = V.graph.get_buffer(buf_name)
-        # output arg
-        if not buffer:
-            assert buf_name == V.kernel.output_node.name
-            layout = V.kernel.output_node.layout
-        else:
-            layout = buffer.get_layout()
-
-    if isinstance(layout, torch._inductor.ir.NonOwningLayout):
-        return not layout.maybe_guard_aligned()
-    else:
-        return False
-
-
 def config_of(
     args: List[KernelArgType],
     *,
@@ -107,7 +81,9 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool:
                 offset_aligned = V.graph.sizevars.statically_known_multiple_of(
                     x.offset * x.dtype.itemsize, alignment  # type: ignore[arg-type]
                 )
-                return offset_aligned and not is_unaligned_buffer(x)
+                return offset_aligned and not V.graph.scheduler.is_unaligned_buffer(
+                    x.buffer
+                )
             else:
                 return False
         if isinstance(x, SizeArg):
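After this change, config_of keeps its static offset check but asks the scheduler about the underlying buffer via V.graph.scheduler.is_unaligned_buffer(x.buffer). A toy sketch of the combined condition, using stand-in helpers instead of the real sizevars and scheduler objects:

ALIGNMENT = 16  # bytes; illustrative value

def statically_known_multiple_of(value, divisor):
    # Stand-in for V.graph.sizevars.statically_known_multiple_of.
    return value % divisor == 0

def tensor_arg_is_aligned(offset_elems, itemsize, buffer_is_unaligned):
    # Aligned only if the byte offset is provably a multiple of ALIGNMENT
    # and the scheduler does not flag the underlying buffer as unaligned.
    offset_aligned = statically_known_multiple_of(offset_elems * itemsize, ALIGNMENT)
    return offset_aligned and not buffer_is_unaligned

print(tensor_arg_is_aligned(4, 4, buffer_is_unaligned=False))  # True: 16 bytes
print(tensor_arg_is_aligned(4, 4, buffer_is_unaligned=True))   # False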

torch/_inductor/scheduler.py

Lines changed: 11 additions & 2 deletions
@@ -2454,9 +2454,18 @@ def codegen(self):

         self.flush()

-    def get_buffer_layout(self, buf_name: str) -> ir.Layout:
+    def is_unaligned_buffer(self, buf_name):
+        if buf_name in V.graph.graph_inputs:
+            return not config.assume_aligned_inputs
+        if buf_name in V.graph.constants:
+            # all constants are assumed to be aligned
+            return False
         node = self.name_to_node[buf_name]
-        return node.node.get_layout()
+        layout = node.node.get_layout()
+        if isinstance(layout, ir.NonOwningLayout):
+            return not layout.maybe_guard_aligned()
+        else:
+            return False


 class BaseScheduling:
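The helper moved back onto the Scheduler checks buffers in a fixed order: graph inputs defer to config.assume_aligned_inputs, constants are always treated as aligned, and otherwise only a NonOwningLayout that cannot guard alignment is reported as unaligned. A self-contained sketch of that decision order, using stub objects rather than the real Scheduler:

class NonOwningLayout:
    # Stub standing in for torch._inductor.ir.NonOwningLayout.
    def __init__(self, provably_aligned):
        self.provably_aligned = provably_aligned

    def maybe_guard_aligned(self):
        return self.provably_aligned

def is_unaligned_buffer(buf_name, graph_inputs, constants, layouts,
                        assume_aligned_inputs=True):
    if buf_name in graph_inputs:
        return not assume_aligned_inputs  # inputs defer to the config flag
    if buf_name in constants:
        return False                      # all constants are assumed aligned
    layout = layouts[buf_name]
    if isinstance(layout, NonOwningLayout):
        return not layout.maybe_guard_aligned()
    return False                          # other layouts are treated as aligned

print(is_unaligned_buffer("buf0", set(), set(), {"buf0": NonOwningLayout(False)}))  # True
print(is_unaligned_buffer("arg0", {"arg0"}, set(), {}))                             # False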

torch/_inductor/select_algorithm.py

Lines changed: 4 additions & 9 deletions
@@ -94,7 +94,7 @@ def __init__(
         grid_fn,
         meta,
         call_sizes,
-        use_jit=False,
+        use_jit=True,
         prefix_args=0,
         suffix_args=0,
         epilogue_fn=identity,
@@ -150,8 +150,8 @@ def jit_lines(self):
         argdefs, _, signature = self.args.python_argdefs()
         triton_meta = {
             "signature": signature_to_meta(signature, size_dtype=self.index_dtype),
-            "device": self.output_node.get_device().index,
-            "device_type": self.output_node.get_device().type,
+            "device": V.graph.scheduler.current_device.index,
+            "device_type": V.graph.scheduler.current_device.type,
             "constants": {},
         }
         triton_meta["configs"] = [config_of(signature)]
@@ -502,7 +502,7 @@ def generate(
         ), TritonTemplateKernel(
             kernel_name=kernel_name,
             output_node=fake_out,
-            use_jit=False,
+            use_jit=True,
             **kernel_options,
         ) as kernel:
             try:
@@ -688,10 +688,6 @@ def benchmark(self, *args, out):
         assert self.bmreq is not None
         return self.bmreq.benchmark(*args, output_tensor=out)

-    def precompile(self):
-        assert self.bmreq is not None
-        self.bmreq.precompile()
-
     def __str__(self):
         return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})"

@@ -832,7 +828,6 @@ def __call__(

        # TODO(nmacchioni): remove once CI tests are fixed
        choices = [choice for choice in choices if choice is not None]
-
        if len(choices) == 0:
            raise RuntimeError(
                "No choices to select, please consider adding ATEN into max_autotune_gemm_backends "

0 commit comments
