Revert "Fix conv2d strided prologue (#150697)" · pytorch/pytorch@caf8d9b · GitHub
[go: up one dir, main page]

Skip to content

Commit caf8d9b

Browse files
Revert "Fix conv2d strided prologue (#150697)"
This reverts commit 2e4ae2a. Reverted #150697 on behalf of https://github.com/ngimel because it breaks the ROCm build (see the comment on #150697).
1 parent 2d98a1c commit caf8d9b

File tree

2 files changed

+6
-43
lines changed

2 files changed

+6
-43
lines changed

test/inductor/test_max_autotune.py

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,44 +1380,6 @@ def check_code(self, code_str, num_kernels, num_allocs, num_deallocs):
13801380
"del", num_deallocs, exactly=True
13811381
).run(code_str)
13821382

1383-
@parametrize("prologue", (False, True))
1384-
def test_conv1x1_cast(self, prologue):
1385-
with torch._inductor.config.patch(prologue_fusion=prologue):
1386-
conv1x1 = (
1387-
torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=1)
1388-
.to(memory_format=torch.channels_last)
1389-
.to(GPU_TYPE)
1390-
.to(dtype=torch.float16)
1391-
)
1392-
input_tensor = (
1393-
torch.randn(4, 3, 32, 32)
1394-
.contiguous(memory_format=torch.channels_last)
1395-
.to(GPU_TYPE)
1396-
)
1397-
1398-
def foo(mod, input):
1399-
return torch.nn.functional.conv2d(
1400-
input,
1401-
mod.weight.to(input.dtype),
1402-
None,
1403-
mod.stride,
1404-
mod.padding,
1405-
mod.dilation,
1406-
mod.groups,
1407-
)
1408-
1409-
with torch.no_grad():
1410-
out_eager = foo(conv1x1, input_tensor)
1411-
foo_c = torch.compile(foo)
1412-
out, code = run_and_get_code(foo_c, conv1x1, input_tensor)
1413-
1414-
FileCheck().check_not("extern_kernels.convolution").run(code[0])
1415-
if prologue:
1416-
self.check_code(
1417-
code[0], num_kernels=1, num_allocs=1, num_deallocs=2
1418-
)
1419-
self.assertEqual(out_eager, out, atol=1e-2, rtol=0)
1420-
14211383
@parametrize("sizes", ((64, 128, 256), (128, 128, 128), (63, 120, 250)))
14221384
def test_upcast(self, sizes):
14231385
M, K, N = sizes

torch/_inductor/select_algorithm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -746,10 +746,11 @@ def load_input(
746746
indices, self.range_trees[0].construct_entries(lengths)
747747
):
748748
range_tree_entry.set_name(name)
749-
750-
strided_index = sympy_dot(input_node.get_stride(), index_symbols)
751-
strided_index = self.rename_indexing(strided_index)
752-
self.body.writeline("xindex = " + texpr(strided_index))
749+
contiguous_index = sympy_dot(
750+
ir.FlexibleLayout.contiguous_strides(lengths), index_symbols
751+
)
752+
contiguous_index = self.rename_indexing(contiguous_index)
753+
self.body.writeline("xindex = " + texpr(contiguous_index))
753754

754755
xindex_range_root = self.range_trees[0].lookup(
755756
sympy.Integer(1), sympy_product(lengths)
@@ -822,7 +823,7 @@ def store(
822823

823824
output_index = self.rename_indexing(output_index)
824825

825-
if output_index == strided_index:
826+
if output_index == contiguous_index:
826827
output_index_str = "xindex"
827828
else:
828829
out_indexing = self.indexing(

0 commit comments

Comments
 (0)
0