[inductor] support dilation in max_pool2d lowering by isuruf · Pull Request #148209 · pytorch/pytorch
[inductor] support dilation in max_pool2d lowering #148209


Closed · isuruf wants to merge 14 commits
3 changes: 2 additions & 1 deletion test/inductor/test_cpu_cpp_wrapper.py
@@ -246,7 +246,8 @@ class BaseTest(NamedTuple):
for func in dir(test_cpu_repro.CPUReproTests())
if func.startswith("test_lstm_packed_change_input_sizes")
],
BaseTest("test_max_pool2d6"),
BaseTest("test_max_pool2d6_dilation_1"),
BaseTest("test_max_pool2d6_dilation_2"),
BaseTest(
"test_mkl_linear", "", test_cpu_repro.CPUReproTests(), condition=TEST_MKL
),
6 changes: 3 additions & 3 deletions test/inductor/test_mps_basic.py
@@ -34,6 +34,7 @@
# This tests basic MPS compile functionality


+ @instantiate_parametrized_tests
class MPSBasicTests(TestCase):
is_dtype_supported = CommonTemplate.is_dtype_supported
common = check_model_gpu
@@ -192,7 +193,8 @@ def fn(a):
"test_lgamma",
"test_linear_float64",
"test_log_fp64",
"test_low_memory_max_pool",
"test_low_memory_max_pool_dilation_1",
"test_low_memory_max_pool_dilation_2",
"test_max_min",
"test_max_pool2d2",
"test_multilayer_prime_size",
@@ -228,8 +230,6 @@ def fn(a):
]:
setattr(MPSBasicTests, test_name, getattr(CommonTemplate, test_name))

- instantiate_parametrized_tests(MPSBasicTests)

if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

18 changes: 10 additions & 8 deletions test/inductor/test_torchinductor.py
@@ -4271,22 +4271,22 @@ def fn2(a):
(torch.randn([2, 2, 10]),),
)

- def test_low_memory_max_pool(self):
+ @parametrize("dilation", (1, 2))
+ def test_low_memory_max_pool(self, dilation: int):
prims = torch.ops.prims

def fn(x):
kernel_size = [3, 3]
stride = [2, 2]
padding = [1, 1]
- dilation = [1, 1]
ceil_mode = False

vals, offsets = prims._low_memory_max_pool2d_with_offsets(
x,
kernel_size,
stride,
padding,
- dilation,
+ [dilation] * 2,
ceil_mode,
)
indices = prims._low_memory_max_pool2d_offsets_to_indices(
@@ -4295,6 +4295,7 @@ def fn(x):
x.size(-1),
stride,
padding,
+ dilation=[dilation] * 2,
)
return vals, indices, offsets

@@ -5044,10 +5045,13 @@ def fn(x):
)

@skip_if_gpu_halide # slow
- def test_max_pool2d6(self):
+ @parametrize("dilation", (1, 2))
+ def test_max_pool2d6(self, dilation: int):
# Big kernel size
def fn(x):
- return aten.max_pool2d_with_indices(x, [13, 13], [])
+ return aten.max_pool2d_with_indices(
+     x, [13, 13], [], dilation=[dilation] * 2
+ )

self.common(
fn,
@@ -5069,16 +5073,14 @@ def fn(x):

# From https://github.com/pytorch/pytorch/issues/93384
def test_max_pool2d8(self):
- # dialtion is not 1, use fallback
+ # dilation is not 1
def fn(x):
return aten.max_pool2d_with_indices(x, [3, 2], [2, 1], [1, 1], [1, 2])

- torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
(torch.randn([2, 2, 3, 6]),),
)
- assertGeneratedKernelCountEqual(self, 0)

def test_avg_pool2d1(self):
def fn(x):
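The renamed BaseTest entries in test_cpu_cpp_wrapper.py above follow directly from the parametrization introduced here: parametrize("dilation", (1, 2)) makes the shared test template expand into one method per value, with the parameter value baked into the name. A minimal self-contained sketch of that mechanism (the class name is hypothetical; the utilities are the same ones used in the diff):

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
)

@instantiate_parametrized_tests
class ExampleTests(TestCase):
    @parametrize("dilation", (1, 2))
    def test_max_pool2d6(self, dilation: int):
        # The decorator expands this into test_max_pool2d6_dilation_1 and
        # test_max_pool2d6_dilation_2, matching the suffixed names now
        # registered in test_cpu_cpp_wrapper.py.
        self.assertIn(dilation, (1, 2))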
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -176,7 +176,6 @@ def run(*ex, **kwargs):
"test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)),
"test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)),
"test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
("cpu", "cuda")
),
7 changes: 3 additions & 4 deletions torch/_inductor/decomposition.py
@@ -986,11 +986,9 @@ def max_pool2d_with_indices(
stride = pad_listlike(stride, 2)

window_size = kernel_size[0] * kernel_size[1]
- # We fallback when using non-default dilation or when the window size is too large
+ # We fallback when the window size is too large
if (
- torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
-     kernel_size, dilation
- )
+ torch._inductor.lowering.should_fallback_max_pool2d_with_indices(kernel_size)
or window_size > torch.iinfo(torch.int8).max
):
return NotImplemented
@@ -1009,6 +1007,7 @@ def max_pool2d_with_indices(
x.size(-1),
stride,
padding,
+ dilation,
)
return vals, indices

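With dilation now handled by the lowering, the decomposition's only remaining guard is window size: the low-memory path stores each window offset in an int8, and the lowering itself declines windows larger than 25 elements. A standalone sketch of the combined condition (hypothetical helper name, not the real API):

import torch

def uses_low_memory_path(kernel_size):
    # Mirrors the two guards above: fall back when the window has more than
    # 25 elements (lowering heuristic) or would overflow the int8 offsets
    # (torch.iinfo(torch.int8).max == 127).
    window_size = kernel_size[0] * kernel_size[1]
    return not (window_size > 25 or window_size > torch.iinfo(torch.int8).max)

print(uses_low_memory_path([3, 3]))    # True: 9 <= 25
print(uses_low_memory_path([13, 13]))  # False: 169 > 25 (the big-kernel test above)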
17 changes: 11 additions & 6 deletions torch/_inductor/inductor_prims.py
@@ -152,16 +152,21 @@ def _low_memory_max_pool2d_with_offsets_aten(
ih = indices // input_width
iw = indices - (ih * input_width)

- h_inc = ih - hbase
- w_inc = iw - wbase
+ h_inc = (ih - hbase) // dilation[0]
+ w_inc = (iw - wbase) // dilation[1]

offsets = h_inc * kernel_width + w_inc

return vals, offsets.to(torch.int8)


def _low_memory_max_pool2d_offsets_to_indices_aten(
- offsets, kernel_width, input_width, stride, padding
+ offsets,
+ kernel_width,
+ input_width,
+ stride,
+ padding,
+ dilation,
):
offsets = offsets.to(torch.int64)
h_inc = offsets // kernel_width
@@ -182,8 +187,8 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
hbase = bh * stride[0] - padding[0]
wbase = bw * stride[1] - padding[1]

- ih = hbase + h_inc
- iw = wbase + w_inc
+ ih = hbase + h_inc * dilation[0]
+ iw = wbase + w_inc * dilation[1]
return ih * input_width + iw


@@ -195,7 +200,7 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
)

_low_memory_max_pool2d_offsets_to_indices = make_prim(
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", # noqa: B950
_low_memory_max_pool2d_offsets_to_indices_aten,
doc="Convert small int offsets to regular indices.",
)
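The two ATen reference functions above are inverses of each other once dilation is threaded through: encoding divides the distance from the window base by the dilation, decoding multiplies it back. A small standalone check of that round trip (plain Python, independent of the prims themselves):

def encode_offset(ih, iw, hbase, wbase, dilation, kernel_w):
    # (ih, iw) is the argmax position; the window taps rows hbase,
    # hbase + dilation[0], ..., so the increment is recovered by division.
    h_inc = (ih - hbase) // dilation[0]
    w_inc = (iw - wbase) // dilation[1]
    return h_inc * kernel_w + w_inc

def decode_offset(offset, hbase, wbase, dilation, kernel_w, input_w):
    h_inc = offset // kernel_w
    w_inc = offset - kernel_w * h_inc
    ih = hbase + h_inc * dilation[0]
    iw = wbase + w_inc * dilation[1]
    return ih * input_w + iw

# 3-wide window with dilation 2: taps columns wbase, wbase + 2, wbase + 4.
off = encode_offset(ih=6, iw=5, hbase=2, wbase=1, dilation=(2, 2), kernel_w=3)
assert off == 8                          # h_inc = 2, w_inc = 2 -> 2 * 3 + 2
assert decode_offset(off, 2, 1, (2, 2), kernel_w=3, input_w=10) == 6 * 10 + 5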
61 changes: 36 additions & 25 deletions torch/_inductor/lowering.py
@@ -4319,14 +4319,22 @@ def load(index):
return load


- def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
+ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
+     if dilation is None:
+         dilation = [1] * len(padding)

x_out = FloorDiv(
- x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
+ x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1),
+ stride[i],
)

if ceil_mode:
x_alt = FloorDiv(
- x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
+ x
+ + 2 * padding[i]
+ - dilation[i] * (kernel_size[i] - 1)
+ + 2 * (stride[i] - 1),
+ stride[i],
)
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
# Sliding windows must start within the input or left padding
Expand All @@ -4341,10 +4349,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
return x_out, ceil_mode
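The dilated output size computed here is the usual floor((x + 2*padding - dilation*(kernel_size - 1) - 1) / stride) + 1, written as a single FloorDiv by folding the trailing +1 into a +(stride - 1) in the numerator. A quick sanity check against eager mode (a sketch, not part of the PR):

import torch

def pooled_dim(x, kernel, stride, pad, dilation):
    # Same algebra as the FloorDiv above:
    # floor((a - 1) / s) + 1 == floor((a + s - 1) / s)
    return (x + 2 * pad - dilation * (kernel - 1) + (stride - 1)) // stride

x = torch.randn(1, 1, 20, 20)
out = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=2, padding=1, dilation=2)
assert out.shape[-1] == pooled_dim(20, kernel=3, stride=2, pad=1, dilation=2) == 9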


- def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
+ def should_fallback_max_pool2d_with_indices(kernel_size):
kernel_size = pad_listlike(kernel_size, 2)
window_size = kernel_size[0] * kernel_size[1]
- return (window_size > 25) or any(d > 1 for d in dilation)
+ return window_size > 25


def max_pool2d_checks(
@@ -4369,7 +4377,7 @@ def max_pool2d_checks(
assert len(dilation) == 2
assert len(x.get_size()) in (3, 4)

- use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
+ use_fallback = should_fallback_max_pool2d_with_indices(kernel_size)
if assert_fallback is not None:
assert use_fallback == assert_fallback

@@ -4387,8 +4395,12 @@ def _max_pool2d_with_offsets(
x.realize_hint()
*batch, h, w = x.get_size()

- h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
- w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
+ h_out, ceil_mode1 = pooling_size(
+     h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation
+ )
+ w_out, ceil_mode2 = pooling_size(
+     w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation
+ )

dtype = x.dtype
min_value = (
Expand All @@ -4398,7 +4410,14 @@ def _max_pool2d_with_offsets(
)

new_size = list(batch) + [h_out, w_out]
- if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
+ if (
+     padding[0]
+     or padding[1]
+     or ceil_mode1
+     or ceil_mode2
+     or (dilation[0] > 1)
+     or (dilation[1] > 1)
+ ):
x_loader = constant_boundary_condition(x, min_value, dim=2)
else:
x_loader = x.make_loader()
@@ -4408,7 +4427,10 @@ def fn_inner(idx, reduction_idx):
def fn_inner(idx, reduction_idx):
prefix = idx[:-dim]
bh = idx[-dim:]
- ih = [bh[i] * stride[i] + reduction_idx[i] - padding[i] for i in range(dim)]
+ ih = [
+     (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+     for i in range(dim)
+ ]
return x_loader([*prefix, *ih])

result = Reduction.create(
@@ -4476,7 +4498,7 @@ def _low_memory_max_pool2d_with_offsets(
prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool2d_offsets_to_indices(
- offsets, kernel_width, input_width, stride, padding
+ offsets, kernel_width, input_width, stride, padding, dilation
):
# TODO: Generalize to other max pooling flavors, and arbitrary dim

@@ -4486,8 +4508,8 @@ def increments_to_index(h_inc, w_inc, bh, bw):
w_in = ops.index_expr(input_width, torch.int64)
hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64)
wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64)
- ih = hbase + h_inc
- iw = wbase + w_inc
+ ih = hbase + h_inc * ops.constant(dilation[0], torch.int64)
+ iw = wbase + w_inc * ops.constant(dilation[1], torch.int64)
return ih * w_in + iw

def offsets_to_indices(idx):
@@ -4507,12 +4529,6 @@ def offsets_to_indices(idx):
return indices


- fallback_max_pool2d_with_indices = fallback_handler(
[Contributor review comment on this deletion: 🎉]

- aten.max_pool2d_with_indices.default,
- add_to_fallback_set=False,
- )


+ # Fallback when we do not decompose to the low-memory path.
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
def max_pool2d_with_indices(
Expand All @@ -4527,17 +4543,12 @@ def max_pool2d_with_indices(
x, kernel_size, stride, padding, dilation
)

- if any(d > 1 for d in dilation):
-     return fallback_max_pool2d_with_indices(
-         x, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode
-     )

out, offsets = _max_pool2d_with_offsets(
x, kernel_size, stride, padding, dilation, ceil_mode
)

indices = _low_memory_max_pool2d_offsets_to_indices(
- offsets, kernel_size[-1], x.shape[-1], stride, padding
+ offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation
)

return out, indices
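Taken together, the change means a dilated max_pool2d now compiles through inductor's lowering instead of hitting the ATen fallback that the deleted fallback_max_pool2d_with_indices handler provided. A quick end-to-end check of the new behavior (a sketch of expected usage, not a test from the PR):

import torch

def fn(x):
    return torch.nn.functional.max_pool2d(
        x, kernel_size=3, stride=2, padding=1, dilation=2, return_indices=True
    )

x = torch.randn(2, 3, 32, 32)
eager_vals, eager_idx = fn(x)
# Previously dilation > 1 forced a fallback to aten; now it is lowered directly.
compiled_vals, compiled_idx = torch.compile(fn)(x)
torch.testing.assert_close(compiled_vals, eager_vals)
torch.testing.assert_close(compiled_idx, eager_idx)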