[Inductor] Fix 3D tiling with permute (#147249) · pytorch/pytorch@1677a31 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1677a31

Browse files
blaine-rister authored and pytorchmergebot committed
[Inductor] Fix 3D tiling with permute (#147249)
This PR adds a test case and tiny fix for 3D tiling. Before this PR, tiling would crash because one of the candidates lacked a `"y"` dimension. Now, when we're calculating 3D tiling candidates, we assume the y size is 1 if it's missing.

The test case implements a 3D permute using block pointers.

```
@triton.jit
def triton_poi_fused_add_0(in_ptr0, out_ptr0, znumel, ynumel, xnumel, ZBLOCK : tl.constexpr, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    znumel = 51
    ynumel = 51
    xnumel = 51
    zoffset = tl.program_id(2) * ZBLOCK
    zindex = zoffset + tl.arange(0, ZBLOCK)[None, None, :]
    zmask = zindex < znumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[None, :, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None, None]
    xmask = xindex < xnumel
    x2 = xindex
    y1 = yindex
    z0 = zindex
    tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp1 = tl.load(tl.make_block_ptr(in_ptr0, shape=[51, 51, 51], strides=[51, 1, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), boundary_check=[0, 1, 2])
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 + tmp0
    tmp4 = tmp2 + tmp3
    tl.store(tl.make_block_ptr(out_ptr0, shape=[51, 51, 51], strides=[1, 51, 2601], block_shape=[XBLOCK, YBLOCK, ZBLOCK], order=[2, 1, 0], offsets=[xoffset, yoffset, zoffset]), tl.broadcast_to(tmp4, [XBLOCK, YBLOCK, ZBLOCK]).to(tl.float32), boundary_check=[0, 1, 2])
```

Pull Request resolved: #147249
Approved by: https://github.com/jansel
1 parent 44ee9ca commit 1677a31

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
from torch.testing._internal.inductor_utils import (
2121
GPU_TYPE,
2222
HAS_GPU,
23+
requires_gpu,
2324
skip_windows_ci,
2425
TRITON_HAS_CPU,
2526
)
@@ -895,6 +896,34 @@ def func(x, y):
895896
)
896897
self.assertTrue("Min" not in code[0])
897898

899+
@requires_gpu() # FIXME this test failed on Triton-CPU
900+
def test_3d_permute_tiling(self):
901+
"""
902+
Test 3D tiling with permute.
903+
"""
904+
905+
def foo(x, y, z):
906+
dims = [0, 2, 1]
907+
a = x.permute(dims=dims) + y
908+
b = (z + y).permute(dims=dims)
909+
return a + b
910+
911+
inps = (torch.rand((51, 51, 51), device=self.device, dtype=torch.float32),) * 3
912+
result, (code,) = run_and_compare(
913+
self,
914+
foo,
915+
*inps,
916+
expected_num_triton_kernels=1,
917+
expected_num_block_pointers=3,
918+
config_patches={
919+
"triton.max_tiles": 3,
920+
"triton.prefer_nd_tiling": True,
921+
},
922+
)
923+
924+
# Check for 3D tiling
925+
self.assertIn("ZBLOCK", code)
926+
898927

899928
@unittest.skipIf(not TRITON_HAS_CPU, "requires triton CPU backend")
900929
@config.patch(cpu_backend="triton")

torch/_inductor/codegen/simd.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2011,8 +2011,8 @@ def select_tiling(
20112011
def convert_tiling_to_3d(
20122012
tiling0: dict[str, sympy.Expr], tiling1: dict[str, sympy.Expr]
20132013
) -> Optional[dict[str, sympy.Expr]]:
2014-
a0, a1 = tiling0["x"], tiling0["y"]
2015-
b0, b1 = tiling1["x"], tiling1["y"]
2014+
a0, a1 = tiling0["x"], tiling0.get("y", 1)
2015+
b0, b1 = tiling1["x"], tiling1.get("y", 1)
20162016
if V.graph.sizevars.size_hint(a1 - b1) == 0:
20172017
return None
20182018
if V.graph.sizevars.size_hint(a1 - b1) < 0:

0 commit comments

Comments
 (0)
0