Update base for Update on "[inductor][cpp] support bf16/fp16 gemm template epilogue fusion" · pytorch/pytorch@0dab245 · GitHub

Commit 0dab245

Author: Jiong Gong (committed)
Update base for Update on "[inductor][cpp] support bf16/fp16 gemm template epilogue fusion"
As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows:

1. bf16 linear with epilogue fusion of some ops was originally supported via the ATen oneDNN linear pointwise ops. To match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we have "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues are concatenated with the out-of-template epilogues appended during scheduling.

2. Support bf16/fp16 legalization for `codegen_loop_bodies`, which is used to generate the epilogue loops.

3. We used to leverage the in-place buffer mechanism to handle in-place buffers in the epilogue codegen, in particular for the reuse of the output buffers of the GEMM, the template, and the epilogues. This is not correct, since the output buffer is an "output", not an "in-place" buffer, of the template kernel itself. Now a dedicated "aliases" dict manages such buffer reuses, and the intermediate aliasing buffers are removed after codegen.

4. Add a `localize_buffer` method to `LocalBufferScope` that allows replacing a global buffer with a local one in the given inductor IR nodes. This helps the fused loops work on smaller local buffers for better data locality.

cc voznesenskym penguinwu EikanWang Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 ColinPeppler amjames desertfire chauhang

[ghstack-poisoned]
2 parents 1318f32 + 25447ba commit 0dab245
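For readers trying out the feature described above, here is a minimal sketch (not part of this commit) of the kind of workload the cpp gemm template with in-template epilogues targets: a bf16 linear followed by a pointwise op, compiled on CPU with max-autotune so the CPP GEMM backend can be selected. The `max_autotune` and `max_autotune_gemm_backends` flags are existing inductor config options and are used here as assumptions about how to enable the template path; whether the template actually wins autotuning depends on shapes and the build.

    # Hedged sketch: bf16 linear + relu epilogue on CPU. With max-autotune and
    # the CPP GEMM backend enabled, inductor may pick the cpp gemm template and
    # fuse the relu as an in-template epilogue.
    import torch
    import torch._inductor.config as inductor_config

    class LinearRelu(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(64, 64)

        def forward(self, x):
            # gemm + pointwise epilogue candidate ("gemm + in-template epilogues")
            return torch.relu(self.linear(x))

    with inductor_config.patch(
        {"max_autotune": True, "max_autotune_gemm_backends": "CPP,ATEN"}
    ):
        model = LinearRelu().to(torch.bfloat16).eval()
        x = torch.randn(8, 64, dtype=torch.bfloat16)
        with torch.no_grad():
            out = torch.compile(model)(x)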

File tree

2 files changed: +19 -8 lines changed

test/inductor/test_aot_inductor.py

Lines changed: 14 additions & 8 deletions
@@ -12,6 +12,7 @@
 import torch
 import torch._export
 import torch._inductor
+import torch._inductor.config
 import torch.nn as nn
 from torch._dynamo.testing import rand_strided, same
 from torch._dynamo.utils import counters
@@ -1313,14 +1314,19 @@ def fn(a, b, alpha=1.0):
         with self.assertRaises(RuntimeError):
             torch._export.aot_compile(fn, args=(a, b), kwargs={"alpha": 2.0})

-        so_path = torch._export.aot_compile(
-            torch.ops.aten.add, args=(a, b), kwargs={"alpha": 2.0}, same_signature=False
-        )
-        kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
-        res = kernel_runner.run([a, b])
-        self.assertTrue(isinstance(res, list))
-        self.assertTrue(len(res) == 1)
-        self.assertEqual(fn(a, b, alpha=2.0), res[0])
+        for simdlen in [0, None]:
+            with torch._inductor.config.patch({"cpp.simdlen": simdlen}):
+                so_path = torch._export.aot_compile(
+                    torch.ops.aten.add,
+                    args=(a, b),
+                    kwargs={"alpha": 2.0},
+                    same_signature=False,
+                )
+                kernel_runner = AOTIRunnerUtil.load_runner(self.device, so_path)
+                res = kernel_runner.run([a, b])
+                self.assertTrue(isinstance(res, list))
+                self.assertTrue(len(res) == 1)
+                self.assertEqual(fn(a, b, alpha=2.0), res[0])

     def test_buffer_mutation_2(self):
         class Model(torch.nn.Module):
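The new loop compiles the same op once with `cpp.simdlen` patched to 0 and once with the default None. The same knob can be exercised outside the AOTInductor test harness; below is a minimal sketch that assumes only the public `torch.compile` and `torch._inductor.config.patch` APIs (reading 0 as "disable explicit vectorization" is an assumption, not something stated in this diff).

    # Hedged sketch: compile a simple op under both cpp.simdlen settings used
    # by the test and check the results against eager execution.
    import torch
    import torch._inductor.config as inductor_config

    def add(a, b, alpha=1.0):
        return torch.add(a, b, alpha=alpha)

    a = torch.randn(32)
    b = torch.randn(32)
    for simdlen in (0, None):  # 0: assumed non-vectorized path; None: default
        with inductor_config.patch({"cpp.simdlen": simdlen}):
            compiled = torch.compile(add, backend="inductor")
            torch.testing.assert_close(compiled(a, b, alpha=2.0), add(a, b, alpha=2.0))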

torch/_inductor/codecache.py

Lines changed: 5 additions & 0 deletions
@@ -1717,6 +1717,11 @@ def get_include_and_linking_paths(
     else:
         libs = ["omp"] if config.is_fbcode() else ["gomp"]

+    # For AOT mode, the produced library relies on torch cpu to set grad mode
+    # like aoti_torch_grad_mode_set_enabled
+    if aot_mode and sys.platform == "linux" and not config.is_fbcode():
+        libs += ["torch", "torch_cpu"]
+
     # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690
     if not config.abi_compatible:
         libs += ["c10"]

0 commit comments

Comments
 (0)
0