[Inductor][CPP] Fix Codegen Issue when Parallel Reduction under the vectorization · pytorch/pytorch@89a621d · GitHub
[go: up one dir, main page]

Skip to content

Commit 89a621d

Browse files
[Inductor][CPP] Fix Codegen Issue when Parallel Reduction under the vectorization
ghstack-source-id: 036aa97 Pull Request resolved: #151887
1 parent 9680016 commit 89a621d

File tree

3 files changed

+42
-3
lines changed

3 files changed

+42
-3
lines changed

test/inductor/test_cpu_repro.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -987,6 +987,33 @@ def fn(x):
987987
# aten parallel.
988988
self.common(fn, (v,), atol=5e-1, rtol=5e-1)
989989

990+
def test_parallel_reduction_vectorization(self):
991+
# Fix issue: https://github.com/pytorch/pytorch/issues/151523
992+
class Model(torch.nn.Module):
993+
def __init__(self):
994+
super().__init__()
995+
self.conv = torch.nn.Conv2d(
996+
in_channels=3,
997+
out_channels=16,
998+
kernel_size=(1, 7),
999+
stride=(2, 1),
1000+
padding=0,
1001+
)
1002+
1003+
def forward(self, x, weight):
1004+
x = self.conv(x)
1005+
x = F.hardshrink(x, lambd=0)
1006+
x = x.view(x.size(0), -1)
1007+
x = torch.mv(weight, x[0])
1008+
return x
1009+
1010+
mod = Model().eval()
1011+
x = torch.randn(2, 3, 127, 255)
1012+
weight = torch.randn(10, 254976)
1013+
# Use same criterion as test_inplace_squeeze_needed
1014+
# for parallel reduction.
1015+
self.common(mod, (x, weight), atol=5e-1, rtol=5e-1)
1016+
9901017
def test_cat_mul(self):
9911018
# https://github.com/pytorch/pytorch/issues/93365
9921019
def fn(p0, p1):

test/inductor/test_flex_attention.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,9 +2026,6 @@ def func(qk, b, h, q, kv):
20262026
self.assertTrue((ref - out).abs().mean() < 1e-2)
20272027

20282028
@supported_platform
2029-
@unittest.skipIf(
2030-
SKIP_UT_ON_CPU, "TODO: fix https://github.com/pytorch/pytorch/issues/151290"
2031-
)
20322029
def test_make_block_mask(self, device):
20332030
def causal_mask(b, h, q_idx, kv_idx):
20342031
return q_idx >= kv_idx

torch/_inductor/codegen/cpp.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5464,6 +5464,15 @@ def max_parallel_depth(self):
54645464
num_steps = num_steps * FloorDiv(loop.size, loop.steps)
54655465
max_depth += 1
54665466

5467+
def get_simd_vec_depth(loops):
5468+
# Return the first loop level which is simd_vec
5469+
for i, loop in enumerate(loops):
5470+
if loop.simd_vec:
5471+
return i
5472+
return None
5473+
5474+
simd_vec_depth = get_simd_vec_depth(self.loops)
5475+
54675476
# When the number of steps of the first inner loop is much larger than the number of steps of
54685477
# all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`.
54695478
if (
@@ -5472,6 +5481,12 @@ def max_parallel_depth(self):
54725481
and isinstance(self.loops[max_depth].size, sympy.Integer)
54735482
and num_steps * 300
54745483
< FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps)
5484+
and not (
5485+
# Disable parallel reduction under the vec loop
5486+
simd_vec_depth is not None
5487+
and max_depth > simd_vec_depth
5488+
and self.loops[max_depth].is_reduction
5489+
)
54755490
):
54765491
start_depth = max_depth
54775492
max_depth = 0

0 commit comments

Comments (0)