[inductor] support dilation in max_pool2d lowering · pytorch/pytorch@18b0691 · GitHub
[go: up one dir, main page]

Skip to content

Commit 18b0691

Browse files
committed
[inductor] support dilation in max_pool2d lowering
ghstack-source-id: cf4dda7 Pull Request resolved: #148209
1 parent fbe6195 commit 18b0691

File tree

7 files changed

+64
-46
lines changed

7 files changed

+64
-46
lines changed

test/inductor/test_cpu_cpp_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ class BaseTest(NamedTuple):
246246
for func in dir(test_cpu_repro.CPUReproTests())
247247
if func.startswith("test_lstm_packed_change_input_sizes")
248248
],
249-
BaseTest("test_max_pool2d6"),
249+
BaseTest("test_max_pool2d6_dilation_1"),
250+
BaseTest("test_max_pool2d6_dilation_2"),
250251
BaseTest(
251252
"test_mkl_linear", "", test_cpu_repro.CPUReproTests(), condition=TEST_MKL
252253
),

test/inductor/test_mps_basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ def fn(a):
183183
"test_lgamma",
184184
"test_linear_float64",
185185
"test_log_fp64",
186-
"test_low_memory_max_pool",
186+
"test_low_memory_max_pool_dilation_1",
187+
"test_low_memory_max_pool_dilation_2",
187188
"test_max_min",
188189
"test_max_pool2d2",
189190
"test_min_max_reduction_nan",

test/inductor/test_torchinductor.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4259,22 +4259,22 @@ def fn2(a):
42594259
(torch.randn([2, 2, 10]),),
42604260
)
42614261

4262-
def test_low_memory_max_pool(self):
4262+
@parametrize("dilation", (1, 2))
4263+
def test_low_memory_max_pool(self, dilation: int):
42634264
prims = torch.ops.prims
42644265

42654266
def fn(x):
42664267
kernel_size = [3, 3]
42674268
stride = [2, 2]
42684269
padding = [1, 1]
4269-
dilation = [1, 1]
42704270
ceil_mode = False
42714271

42724272
vals, offsets = prims._low_memory_max_pool2d_with_offsets(
42734273
x,
42744274
kernel_size,
42754275
stride,
42764276
padding,
4277-
dilation,
4277+
[dilation] * 2,
42784278
ceil_mode,
42794279
)
42804280
indices = prims._low_memory_max_pool2d_offsets_to_indices(
@@ -4283,6 +4283,7 @@ def fn(x):
42834283
x.size(-1),
42844284
stride,
42854285
padding,
4286+
dilation=[dilation] * 2,
42864287
)
42874288
return vals, indices, offsets
42884289

@@ -5019,10 +5020,13 @@ def fn(x):
50195020
)
50205021

50215022
@skip_if_gpu_halide # slow
5022-
def test_max_pool2d6(self):
5023+
@parametrize("dilation", (1, 2))
5024+
def test_max_pool2d6(self, dilation: int):
50235025
# Big kernel size
50245026
def fn(x):
5025-
return aten.max_pool2d_with_indices(x, [13, 13], [])
5027+
return aten.max_pool2d_with_indices(
5028+
x, [13, 13], [], dilation=[dilation] * 2
5029+
)
50265030

50275031
self.common(
50285032
fn,
@@ -5044,16 +5048,14 @@ def fn(x):
50445048

50455049
# From https://github.com/pytorch/pytorch/issues/93384
50465050
def test_max_pool2d8(self):
5047-
# dialtion is not 1, use fallback
5051+
# dilation is not 1
50485052
def fn(x):
50495053
return aten.max_pool2d_with_indices(x, [3, 2], [2, 1], [1, 1], [1, 2])
50505054

5051-
torch._inductor.metrics.generated_kernel_count = 0
50525055
self.common(
50535056
fn,
50545057
(torch.randn([2, 2, 3, 6]),),
50555058
)
5056-
assertGeneratedKernelCountEqual(self, 0)
50575059

50585060
def test_avg_pool2d1(self):
50595061
def fn(x):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,6 @@ def run(*ex, **kwargs):
174174
"test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
175175
"test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)),
176176
"test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)),
177-
"test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
178177
"test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
179178
("cpu", "cuda")
180179
),

torch/_inductor/decomposition.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -986,11 +986,9 @@ def max_pool2d_with_indices(
986986
stride = pad_listlike(stride, 2)
987987

988988
window_size = kernel_size[0] * kernel_size[1]
989-
# We fallback when using non-default dilation or when the window size is too large
989+
# We fallback when the window size is too large
990990
if (
991-
torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
992-
kernel_size, dilation
993-
)
991+
torch._inductor.lowering.should_fallback_max_pool2d_with_indices(kernel_size)
994992
or window_size > torch.iinfo(torch.int8).max
995993
):
996994
return NotImplemented
@@ -1009,6 +1007,7 @@ def max_pool2d_with_indices(
10091007
x.size(-1),
10101008
stride,
10111009
padding,
1010+
dilation,
10121011
)
10131012
return vals, indices
10141013

torch/_inductor/inductor_prims.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,21 @@ def _low_memory_max_pool2d_with_offsets_aten(
141141
ih = indices // input_width
142142
iw = indices - (ih * input_width)
143143

144-
h_inc = ih - hbase
145-
w_inc = iw - wbase
144+
h_inc = (ih - hbase) // dilation[0]
145+
w_inc = (iw - wbase) // dilation[1]
146146

147147
offsets = h_inc * kernel_width + w_inc
148148

149149
return vals, offsets.to(torch.int8)
150150

151151

152152
def _low_memory_max_pool2d_offsets_to_indices_aten(
153-
offsets, kernel_width, input_width, stride, padding
153+
offsets,
154+
kernel_width,
155+
input_width,
156+
stride,
157+
padding,
158+
dilation,
154159
):
155160
offsets = offsets.to(torch.int64)
156161
h_inc = offsets // kernel_width
@@ -171,8 +176,8 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
171176
hbase = bh * stride[0] - padding[0]
172177
wbase = bw * stride[1] - padding[1]
173178

174-
ih = hbase + h_inc
175-
iw = wbase + w_inc
179+
ih = hbase + h_inc * dilation[0]
180+
iw = wbase + w_inc * dilation[1]
176181
return ih * input_width + iw
177182

178183

@@ -184,7 +189,7 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
184189
)
185190

186191
_low_memory_max_pool2d_offsets_to_indices = make_prim(
187-
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950
192+
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", # noqa: B950
188193
_low_memory_max_pool2d_offsets_to_indices_aten,
189194
doc="Convert small int offsets to regular indices.",
190195
)

torch/_inductor/lowering.py

Lines changed: 36 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4297,14 +4297,22 @@ def load(index):
42974297
return load
42984298

42994299

4300-
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
4300+
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
4301+
if dilation is None:
4302+
dilation = [1] * len(padding)
4303+
43014304
x_out = FloorDiv(
4302-
x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
4305+
x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1),
4306+
stride[i],
43034307
)
43044308

43054309
if ceil_mode:
43064310
x_alt = FloorDiv(
4307-
x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
4311+
x
4312+
+ 2 * padding[i]
4313+
- dilation[i] * (kernel_size[i] - 1)
4314+
+ 2 * (stride[i] - 1),
4315+
stride[i],
43084316
)
43094317
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
43104318
# Sliding windows must start within the input or left padding
@@ -4319,10 +4327,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
43194327
return x_out, ceil_mode
43204328

43214329

4322-
def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
4330+
def should_fallback_max_pool2d_with_indices(kernel_size):
43234331
kernel_size = pad_listlike(kernel_size, 2)
43244332
window_size = kernel_size[0] * kernel_size[1]
4325-
return (window_size > 25) or any(d > 1 for d in dilation)
4333+
return window_size > 25
43264334

43274335

43284336
def max_pool2d_checks(
@@ -4347,7 +4355,7 @@ def max_pool2d_checks(
43474355
assert len(dilation) == 2
43484356
assert len(x.get_size()) in (3, 4)
43494357

4350-
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
4358+
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size)
43514359
if assert_fallback is not None:
43524360
assert use_fallback == assert_fallback
43534361

@@ -4365,8 +4373,12 @@ def _max_pool2d_with_offsets(
43654373
x.realize_hint()
43664374
*batch, h, w = x.get_size()
43674375

4368-
h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
4369-
w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
4376+
h_out, ceil_mode1 = pooling_size(
4377+
h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation
4378+
)
4379+
w_out, ceil_mode2 = pooling_size(
4380+
w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation
4381+
)
43704382

43714383
dtype = x.dtype
43724384
min_value = (
@@ -4376,7 +4388,14 @@ def _max_pool2d_with_offsets(
43764388
)
43774389

43784390
new_size = list(batch) + [h_out, w_out]
4379-
if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
4391+
if (
4392+
padding[0]
4393+
or padding[1]
4394+
or ceil_mode1
4395+
or ceil_mode2
4396+
or (dilation[0] > 1)
4397+
or (dilation[1] > 1)
4398+
):
43804399
x_loader = constant_boundary_condition(x, min_value, dim=2)
43814400
else:
43824401
x_loader = x.make_loader()
@@ -4386,7 +4405,10 @@ def _max_pool2d_with_offsets(
43864405
def fn_inner(idx, reduction_idx):
43874406
prefix = idx[:-dim]
43884407
bh = idx[-dim:]
4389-
ih = [bh[i] * stride[i] + reduction_idx[i] - padding[i] for i in range(dim)]
4408+
ih = [
4409+
(bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
4410+
for i in range(dim)
4411+
]
43904412
return x_loader([*prefix, *ih])
43914413

43924414
result = Reduction.create(
@@ -4454,7 +4476,7 @@ def _low_memory_max_pool2d_with_offsets(
44544476
prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
44554477
)
44564478
def _low_memory_max_pool2d_offsets_to_indices(
4457-
offsets, kernel_width, input_width, stride, padding
4479+
offsets, kernel_width, input_width, stride, padding, dilation
44584480
):
44594481
# TODO: Generalize to other max pooling flavors, and arbitrary dim
44604482

@@ -4464,8 +4486,8 @@ def increments_to_index(h_inc, w_inc, bh, bw):
44644486
w_in = ops.index_expr(input_width, torch.int64)
44654487
hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64)
44664488
wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64)
4467-
ih = hbase + h_inc
4468-
iw = wbase + w_inc
4489+
ih = hbase + h_inc * ops.constant(dilation[0], torch.int64)
4490+
iw = wbase + w_inc * ops.constant(dilation[1], torch.int64)
44694491
return ih * w_in + iw
44704492

44714493
def offsets_to_indices(idx):
@@ -4485,12 +4507,6 @@ def offsets_to_indices(idx):
44854507
return indices
44864508

44874509

4488-
fallback_max_pool2d_with_indices = fallback_handler(
4489-
aten.max_pool2d_with_indices.default,
4490-
add_to_fallback_set=False,
4491-
)
4492-
4493-
44944510
# Fallback when we do not decompose to the low-memory path.
44954511
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
44964512
def max_pool2d_with_indices(
@@ -4505,17 +4521,12 @@ def max_pool2d_with_indices(
45054521
x, kernel_size, stride, padding, dilation
45064522
)
45074523

4508-
if any(d > 1 for d in dilation):
4509-
return fallback_max_pool2d_with_indices(
4510-
x, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode
4511-
)
4512-
45134524
out, offsets = _max_pool2d_with_offsets(
45144525
x, kernel_size, stride, padding, dilation, ceil_mode
45154526
)
45164527

45174528
indices = _low_memory_max_pool2d_offsets_to_indices(
4518-
offsets, kernel_size[-1], x.shape[-1], stride, padding
4529+
offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation
45194530
)
45204531

45214532
return out, indices

0 commit comments

Comments (0)