Move prologue_supported_inputs computations to def_kernel (#150869) · pytorch/pytorch@4bcff4a · GitHub

Commit 4bcff4a

laithsakka authored and pytorchmergebot committed
Move prologue_supported_inputs computations to def_kernel (#150869)
This avoids replaying load_input on a cache hit in the generated-code cache (GeneratedCodeCache). The idea is that if a template has prologue_loads_all_inputs = True, all of its inputs are loaded, and hence there is no need to replay the load_input events.

Effect on the current benchmark in a local run on a dev server: 18549985383 -> 15072230073 and 25697270062 -> 20738613297.

Pull Request resolved: #150869
Approved by: https://github.com/eellison
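For context, the cache works record/replay style: template hooks invoked during codegen are recorded as events and replayed on a cache hit to rebuild input-dependent state. The sketch below is a minimal illustration of why the flag removes the load_input events; the class and method bodies are simplified stand-ins, not Inductor's actual implementation.

```python
# Minimal sketch of the record/replay idea (illustrative, not Inductor's API).
class RecordingKernel:
    def __init__(self, prologue_loads_all_inputs=False):
        self.prologue_loads_all_inputs = prologue_loads_all_inputs
        self.prologue_supported_inputs = set()
        self.events = []  # replayed on a cache hit to rebuild side effects

    def def_kernel(self, *argnames):
        # With prologue_loads_all_inputs=True, every input is registered
        # here, so no per-input load_input event needs to be replayed.
        if self.prologue_loads_all_inputs:
            self.prologue_supported_inputs.update(argnames)
        self.events.append(("def_kernel", list(argnames), {}))

    def load_input(self, name):
        # Only templates that load a subset of inputs still record this
        # event; recording it for every input is what the change avoids.
        if not self.prologue_loads_all_inputs:
            self.prologue_supported_inputs.add(name)
            self.events.append(("load_input", [name], {}))


kernel = RecordingKernel(prologue_loads_all_inputs=True)
kernel.def_kernel("A", "B")
kernel.load_input("A")
kernel.load_input("B")
assert kernel.events == [("def_kernel", ["A", "B"], {})]
assert kernel.prologue_supported_inputs == {"A", "B"}
```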
1 parent 4421aee commit 4bcff4a

File tree

3 files changed: +33 −17 lines changed

test/inductor/test_max_autotune.py

Lines changed: 15 additions & 16 deletions
```diff
@@ -1188,7 +1188,7 @@ def func_test1(x, y, z, m):
         cache_key, events = get_cache_key_and_events()
 
         if not TEST_WITH_ROCM:
-            self.assertEqual(
+            self.assertExpectedInline(
                 remove_white_space(cache_key),
                 remove_white_space(
                     """
@@ -1204,13 +1204,7 @@ def func_test1(x, y, z, m):
 
             self.assertEqual(
                 remove_white_space(events),
-                remove_white_space(
-                    """[
-                    ('def_kernel', ['A', 'B'], {}),
-                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
-                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
-                    """
-                ),
+                remove_white_space("""[('def_kernel', ['A', 'B'], {})]"""),
             )
 
         # Test symbolic shapes with different symbols. Will cache miss due to different symbols in inputs.
@@ -1232,7 +1226,7 @@ def func_test1(x, y, z, m):
         cache_key, events = get_cache_key_and_events()
 
         if not TEST_WITH_ROCM:
-            self.assertEqual(
+            self.assertExpectedInline(
                 remove_white_space(cache_key),
                 remove_white_space(
                     """{'input_nodes': ["[[s77, s17], [s17, 1], torch.float32, device(type='cuda', index=0), 0]",
@@ -1245,16 +1239,21 @@ def func_test1(x, y, z, m):
                 ),
             )
 
-            self.assertEqual(
+            self.assertExpectedInline(
+                remove_white_space(events),
+                remove_white_space(
+                    """[('def_kernel',['A','B'],{}),('size',['A',0],{}),('size',['B',1],{}),('size',['A',1],{})]"""
+                ),
+            )
+            self.assertExpectedInline(
                 remove_white_space(events),
                 remove_white_space(
                     """[
-                    ('def_kernel', ['A', 'B'], {}),
-                    ('size', ['A', 0], {}), ('size', ['B', 1], {}),
-                    ('size', ['A', 1], {}),
-                    ('load_input', ['A', 'a', ('idx_m', 'idx_n')], {'mask': 'a_mask', 'indent_width': 8}),
-                    ('load_input', ['B', 'b', ('idx_m', 'idx_n')], {'mask': 'b_mask', 'indent_width': 8})]
-                    """
+                    ('def_kernel', ['A', 'B'], {}),
+                    ('size', ['A', 0], {}),
+                    ('size', ['B', 1], {}),
+                    ('size', ['A', 1], {})]
+                    """
                 ),
             )
```
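Read side by side, the expected-event changes show the effect directly: the load_input entries disappear from the recorded stream, and in the symbolic-shape case only the size queries remain alongside def_kernel. Pulling the literals from the first test case above for comparison:

```python
# Literals taken from the first test case above, shown side by side.
events_before = [
    ("def_kernel", ["A", "B"], {}),
    ("load_input", ["A", "a", ("idx_m", "idx_n")], {"mask": "a_mask", "indent_width": 8}),
    ("load_input", ["B", "b", ("idx_m", "idx_n")], {"mask": "b_mask", "indent_width": 8}),
]
events_after = [("def_kernel", ["A", "B"], {})]
```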

torch/_inductor/kernel/mm.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -230,6 +230,7 @@
     """
     ),
     cache_codegen_enabled_for_template=True,
+    prologue_loads_all_inputs=True,
 )
 
 persistent_tma_mm_template = TritonTemplate(
```
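A template opts in by passing the new constructor flag, which the template then forwards into the kernel kwargs (see the generate_and_load hunk below). A minimal runnable sketch of that threading, with a hypothetical stand-in class rather than the real TritonTemplate:

```python
# Hypothetical stand-in for TritonTemplate; only the two keyword flags
# and the forwarded kwargs entry mirror the diffs in this commit.
class TemplateSketch:
    def __init__(self, cache_codegen_enabled_for_template=False,
                 prologue_loads_all_inputs=False):
        self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
        self.prologue_loads_all_inputs = prologue_loads_all_inputs

    def kernel_kwargs(self):
        # Mirrors the "prologue_loads_all_inputs": self.prologue_loads_all_inputs
        # entry added to generate_and_load in select_algorithm.py below.
        return {"prologue_loads_all_inputs": self.prologue_loads_all_inputs}


mm_like = TemplateSketch(cache_codegen_enabled_for_template=True,
                         prologue_loads_all_inputs=True)
assert mm_like.kernel_kwargs() == {"prologue_loads_all_inputs": True}
```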

torch/_inductor/select_algorithm.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -311,6 +311,7 @@ def __init__(
         epilogue_fn=identity,
         subgraphs: Optional[list[ir.ComputedBuffer]] = None,
         workspace_arg: Optional[WorkspaceArg] = None,
+        prologue_loads_all_inputs=False,
     ) -> None:
         numel = sympy_product(output_node.get_size())
         super().__init__(
@@ -387,6 +388,10 @@ def __init__(
         # Update each time an input is marked frozen, used to replay the freezing of inputs on a cache hit.
         self.frozen_layouts_cnt = 0
 
+        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
+        # by adding all inputs.
+        self.prologue_loads_all_inputs = prologue_loads_all_inputs
+
     def input_dependent_preserved_state(self) -> str:
         # Not adding self.args.output_buffers on purpose. But we do not need to reproduce it on a cache hit.
         # (never accessed).
@@ -428,6 +433,7 @@ def set_subgraph_body(self, body_name: str):
             key.name: getattr(self, key.name)
             for key in dataclasses.fields(SubgraphInfo)
         }
+
         assert body_name in self.subgraph_bodies, body_name
 
         subgraph = self.subgraph_bodies[body_name]
@@ -585,10 +591,13 @@ def def_kernel(self, *argnames):
         # The args may be duplicated, so renaming must be after args are de-duplicated.
         for name in argnames:
             input_node = self.named_input_nodes[name]
+            if self.prologue_loads_all_inputs:
+                self.prologue_supported_inputs.add(input_node.get_name())
             if input_node.get_name() in V.graph.removed_buffers:
                 continue
             if input_node.get_name() in self.prologue_fused_inputs:
                 continue
+
             arg_name = self.args.input_buffers[input_node.get_name()]
             if input_node.get_layout().offset == 0:
                 renames.writeline(f"{name} = {arg_name}")
@@ -756,7 +765,9 @@ def load_input(
         """
 
         input_node = self.named_input_nodes[input_name]
-        self.prologue_supported_inputs.add(input_node.get_name())
+        if not self.prologue_loads_all_inputs:
+            self.prologue_supported_inputs.add(input_node.get_name())
+
         tilings = (sympy_product(input_node.get_size()), sympy.Integer(1))
         groups = {
             "x": tilings[0],
@@ -1261,6 +1272,7 @@ def __init__(
         source: str,
         debug=False,
         cache_codegen_enabled_for_template=False,
+        prologue_loads_all_inputs=False,
     ) -> None:
         super().__init__(name)
         self.grid = grid
@@ -1271,6 +1283,9 @@ def __init__(
         self._cache_codegen_enabled_for_template = cache_codegen_enabled_for_template
         self._generated_code_cache: GeneratedCodeCache = GeneratedCodeCache()
         clear_on_fresh_inductor_cache(self._generated_code_cache)
+        # When prologue_loads_all_inputs is true, prologue_supported_inputs is populated during def_kernel
+        # by adding all inputs.
+        self.prologue_loads_all_inputs = prologue_loads_all_inputs
 
         # When this flag is on, we ensure that the cached results and the generated result if cache
         # was not used are the same.
@@ -1370,6 +1385,7 @@ def generate_and_load(
                 "suffix_args": suffix_args,
                 "epilogue_fn": epilogue_fn,
                 "subgraphs": subgraphs,
+                "prologue_loads_all_inputs": self.prologue_loads_all_inputs,
             }
 
             if HAS_WARP_SPEC:
```
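Taken together, the def_kernel and load_input guards above mean exactly one code path populates prologue_supported_inputs, depending on the template flag. A runnable distillation (register_prologue_inputs is a hypothetical helper, not Inductor API; plain strings stand in for the IR buffer names the real code uses):

```python
# Hypothetical helper distilling the two guards above.
def register_prologue_inputs(prologue_loads_all_inputs, kernel_inputs, loaded_inputs):
    supported = set()
    if prologue_loads_all_inputs:
        # def_kernel path: every input named in the kernel signature is
        # supported up front, so load_input events need not be replayed.
        supported.update(kernel_inputs)
    else:
        # load_input path: only the inputs the template actually loads.
        supported.update(loaded_inputs)
    return supported


# Template that loads all inputs (e.g. mm_template after this commit):
assert register_prologue_inputs(True, {"A", "B"}, set()) == {"A", "B"}
# Template that loads only a subset of its inputs:
assert register_prologue_inputs(False, {"A", "B"}, {"A"}) == {"A"}
```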

0 commit comments