[Inductor][CPP] Fix Codegen Issue when Parallel Reduction under the vectorization by leslie-fang-intel · Pull Request #151887 · pytorch/pytorch (Closed)
test/inductor/test_cpu_repro.py (27 additions, 0 deletions)
@@ -987,6 +987,33 @@ def fn(x):
         # aten parallel.
         self.common(fn, (v,), atol=5e-1, rtol=5e-1)

+    def test_parallel_reduction_vectorization(self):
+        # Fix issue: https://github.com/pytorch/pytorch/issues/151523
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(
+                    in_channels=3,
+                    out_channels=16,
+                    kernel_size=(1, 7),
+                    stride=(2, 1),
+                    padding=0,
+                )
+
+            def forward(self, x, weight):
+                x = self.conv(x)
+                x = F.hardshrink(x, lambd=0)
+                x = x.view(x.size(0), -1)
+                x = torch.mv(weight, x[0])
+                return x
+
+        mod = Model().eval()
+        x = torch.randn(2, 3, 127, 255)
+        weight = torch.randn(10, 254976)
+        # Use same criterion as test_inplace_squeeze_needed
+        # for parallel reduction.
+        self.common(mod, (x, weight), atol=5e-1, rtol=5e-1)
+
     def test_cat_mul(self):
         # https://github.com/pytorch/pytorch/issues/93365
         def fn(p0, p1):
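For context, here is a minimal standalone sketch (not part of the diff) of the failing pattern from issue #151523 that the test above captures. The model and shapes mirror the test; `torch.compile` plus `torch.testing.assert_close` stand in for the test harness's `self.common`:

import torch
import torch.nn.functional as F

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Conv output is (2, 16, 64, 249); flattening sample 0 gives
        # 254976 elements, matching the reduction dim of the mv below.
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=(1, 7), stride=(2, 1), padding=0)

    def forward(self, x, weight):
        x = self.conv(x)
        x = F.hardshrink(x, lambd=0)
        x = x.view(x.size(0), -1)
        return torch.mv(weight, x[0])  # large reduction over 254976 elements

mod = Model().eval()
x = torch.randn(2, 3, 127, 255)
weight = torch.randn(10, 254976)
with torch.no_grad():
    eager = mod(x, weight)
    compiled = torch.compile(mod)(x, weight)
torch.testing.assert_close(compiled, eager, atol=5e-1, rtol=5e-1)

The loose tolerances match the test: the mv reduces over a quarter-million elements, so eager and compiled results can legitimately diverge by more than default tolerances.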
test/inductor/test_flex_attention.py (0 additions, 3 deletions)
@@ -2026,9 +2026,6 @@ def func(qk, b, h, q, kv):
         self.assertTrue((ref - out).abs().mean() < 1e-2)

     @supported_platform
-    @unittest.skipIf(
-        SKIP_UT_ON_CPU, "TODO: fix https://github.com/pytorch/pytorch/issues/151290"
-    )
     def test_make_block_mask(self, device):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
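As a usage note (not part of the diff): removing the `skipIf` re-enables block-mask construction on CPU, which this fix makes safe again. A minimal sketch of the API the test exercises, with illustrative B/H/Q_LEN/KV_LEN values:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    # Keep only query positions at or after the key position.
    return q_idx >= kv_idx

# Build a block mask on CPU; the sizes here are illustrative, not from the test.
block_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=1024, KV_LEN=1024, device="cpu")
print(block_mask)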
torch/_inductor/codegen/cpp.py (15 additions, 0 deletions)
@@ -5464,6 +5464,15 @@ def max_parallel_depth(self):
             num_steps = num_steps * FloorDiv(loop.size, loop.steps)
             max_depth += 1

+        def get_simd_vec_depth(loops):
+            # Return the first loop level which is simd_vec
+            for i, loop in enumerate(loops):
+                if loop.simd_vec:
+                    return i
+            return None
+
+        simd_vec_depth = get_simd_vec_depth(self.loops)
+
         # When the number of steps of the first inner loop is much larger than the number of steps of
         # all outer loops, change `start_depth` to the first inner loop and recalculate `max_depth`.
         if (
@@ -5472,6 +5481,12 @@ def max_parallel_depth(self):
             and isinstance(self.loops[max_depth].size, sympy.Integer)
             and num_steps * 300
             < FloorDiv(self.loops[max_depth].size, self.loops[max_depth].steps)
+            and not (
+                # Disable parallel reduction under the vec loop
+                simd_vec_depth is not None
+                and max_depth > simd_vec_depth
+                and self.loops[max_depth].is_reduction
+            )
         ):
             start_depth = max_depth
             max_depth = 0
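To make the guard concrete, here is a simplified, self-contained model of the decision. The `Loop` dataclass and the driver values at the bottom are assumptions for illustration; only the condition itself mirrors the code above:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Loop:
    size: int                   # trip count of the loop
    steps: int                  # step size (vector width for SIMD loops)
    simd_vec: bool = False      # True if this level is the vectorized loop
    is_reduction: bool = False  # True if this level performs a reduction

def get_simd_vec_depth(loops) -> Optional[int]:
    # First loop level that is vectorized, mirroring the helper in the diff.
    for i, loop in enumerate(loops):
        if loop.simd_vec:
            return i
    return None

def move_parallel_level_inward(loops, max_depth: int, num_steps: int) -> bool:
    # The inner loop must dominate the outer trip counts by a wide margin ...
    inner = loops[max_depth]
    much_larger = num_steps * 300 < inner.size // inner.steps
    # ... and, with this fix, must not be a reduction nested under the
    # vectorized level, which is the pattern that produced wrong codegen.
    simd_vec_depth = get_simd_vec_depth(loops)
    reduction_under_vec = (
        simd_vec_depth is not None
        and max_depth > simd_vec_depth
        and inner.is_reduction
    )
    return much_larger and not reduction_under_vec

# Illustrative loop nest: a small vectorized outer loop over a huge inner
# reduction (shapes loosely modeled on the mv reduction in the repro).
loops = [Loop(size=16, steps=16, simd_vec=True),
         Loop(size=254976, steps=1, is_reduction=True)]
print(move_parallel_level_inward(loops, max_depth=1, num_steps=1))  # False after the fix

Before the fix, the size check alone would have moved the parallel level onto the inner reduction even though it sits below the SIMD loop; the added `and not (...)` clause keeps parallelism at the outer, non-vectorized levels in that case.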