[inductor] support dilation in max_pool2d lowering · pytorch/pytorch@3e3f459 · GitHub
[go: up one dir, main page]

Skip to content

Commit 3e3f459

Browse files
committed
[inductor] support dilation in max_pool2d lowering
ghstack-source-id: 587edc2 Pull Request resolved: #148209
1 parent 5d4b5ee commit 3e3f459

File tree

7 files changed

+65
-48
lines changed

7 files changed

+65
-48
lines changed

test/inductor/test_cpu_cpp_wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ class BaseTest(NamedTuple):
246246
for func in dir(test_cpu_repro.CPUReproTests())
247247
if func.startswith("test_lstm_packed_change_input_sizes")
248248
],
249-
BaseTest("test_max_pool2d6"),
249+
BaseTest("test_max_pool2d6_dilation_1"),
250+
BaseTest("test_max_pool2d6_dilation_2"),
250251
BaseTest(
251252
"test_mkl_linear", "", test_cpu_repro.CPUReproTests(), condition=TEST_MKL
252253
),

test/inductor/test_mps_basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
# This tests basic MPS compile functionality
3535

3636

37+
@instantiate_parametrized_tests
3738
class MPSBasicTests(TestCase):
3839
is_dtype_supported = CommonTemplate.is_dtype_supported
3940
common = check_model_gpu
@@ -192,7 +193,8 @@ def fn(a):
192193
"test_lgamma",
193194
"test_linear_float64",
194195
"test_log_fp64",
195-
"test_low_memory_max_pool",
196+
"test_low_memory_max_pool_dilation_1",
197+
"test_low_memory_max_pool_dilation_2",
196198
"test_max_min",
197199
"test_max_pool2d2",
198200
"test_multilayer_prime_size",
@@ -228,8 +230,6 @@ def fn(a):
228230
]:
229231
setattr(MPSBasicTests, test_name, getattr(CommonTemplate, test_name))
230232

231-
instantiate_parametrized_tests(MPSBasicTests)
232-
233233
if __name__ == "__main__":
234234
from torch._dynamo.test_case import run_tests
235235

test/inductor/test_torchinductor.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4271,22 +4271,22 @@ def fn2(a):
42714271
(torch.randn([2, 2, 10]),),
42724272
)
42734273

4274-
def test_low_memory_max_pool(self):
4274+
@parametrize("dilation", (1, 2))
4275+
def test_low_memory_max_pool(self, dilation: int):
42754276
prims = torch.ops.prims
42764277

42774278
def fn(x):
42784279
kernel_size = [3, 3]
42794280
stride = [2, 2]
42804281
padding = [1, 1]
4281-
dilation = [1, 1]
42824282
ceil_mode = False
42834283

42844284
vals, offsets = prims._low_memory_max_pool2d_with_offsets(
42854285
x,
42864286
kernel_size,
42874287
stride,
42884288
padding,
4289-
dilation,
4289+
[dilation] * 2,
42904290
ceil_mode,
42914291
)
42924292
indices = prims._low_memory_max_pool2d_offsets_to_indices(
@@ -4295,6 +4295,7 @@ def fn(x):
42954295
x.size(-1),
42964296
stride,
42974297
padding,
4298+
dilation=[dilation] * 2,
42984299
)
42994300
return vals, indices, offsets
43004301

@@ -5044,10 +5045,13 @@ def fn(x):
50445045
)
50455046

50465047
@skip_if_gpu_halide # slow
5047-
def test_max_pool2d6(self):
5048+
@parametrize("dilation", (1, 2))
5049+
def test_max_pool2d6(self, dilation: int):
50485050
# Big kernel size
50495051
def fn(x):
5050-
return aten.max_pool2d_with_indices(x, [13, 13], [])
5052+
return aten.max_pool2d_with_indices(
5053+
x, [13, 13], [], dilation=[dilation] * 2
5054+
)
50515055

50525056
self.common(
50535057
fn,
@@ -5069,16 +5073,14 @@ def fn(x):
50695073

50705074
# From https://github.com/pytorch/pytorch/issues/93384
50715075
def test_max_pool2d8(self):
5072-
# dialtion is not 1, use fallback
5076+
# dilation is not 1
50735077
def fn(x):
50745078
return aten.max_pool2d_with_indices(x, [3, 2], [2, 1], [1, 1], [1, 2])
50755079

5076-
torch._inductor.metrics.generated_kernel_count = 0
50775080
self.common(
50785081
fn,
50795082
(torch.randn([2, 2, 3, 6]),),
50805083
)
5081-
assertGeneratedKernelCountEqual(self, 0)
50825084

50835085
def test_avg_pool2d1(self):
50845086
def fn(x):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,6 @@ def run(*ex, **kwargs):
176176
"test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
177177
"test_logcumsumexp_dynamic_shapes": TestFailure(("cpu",)),
178178
"test_logcumsumexp_zero_dim_dynamic_shapes": TestFailure(("cpu",)),
179-
"test_max_pool2d8_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
180179
"test_max_pool2d_with_indices_backward5_dynamic_shapes": TestFailure(
181180
("cpu", "cuda")
182181
),

torch/_inductor/decomposition.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -986,11 +986,9 @@ def max_pool2d_with_indices(
986986
stride = pad_listlike(stride, 2)
987987

988988
window_size = kernel_size[0] * kernel_size[1]
989-
# We fallback when using non-default dilation or when the window size is too large
989+
# We fallback when the window size is too large
990990
if (
991-
torch._inductor.lowering.should_fallback_max_pool2d_with_indices(
992-
kernel_size, dilation
993-
)
991+
torch._inductor.lowering.should_fallback_max_pool2d_with_indices(kernel_size)
994992
or window_size > torch.iinfo(torch.int8).max
995993
):
996994
return NotImplemented
@@ -1009,6 +1007,7 @@ def max_pool2d_with_indices(
10091007
x.size(-1),
10101008
stride,
10111009
padding,
1010+
dilation,
10121011
)
10131012
return vals, indices
10141013

torch/_inductor/inductor_prims.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -152,16 +152,21 @@ def _low_memory_max_pool2d_with_offsets_aten(
152152
ih = indices // input_width
153153
iw = indices - (ih * input_width)
154154

155-
h_inc = ih - hbase
156-
w_inc = iw - wbase
155+
h_inc = (ih - hbase) // dilation[0]
156+
w_inc = (iw - wbase) // dilation[1]
157157

158158
offsets = h_inc * kernel_width + w_inc
159159

160160
return vals, offsets.to(torch.int8)
161161

162162

163163
def _low_memory_max_pool2d_offsets_to_indices_aten(
164-
offsets, kernel_width, input_width, stride, padding
164+
offsets,
165+
kernel_width,
166+
input_width,
167+
stride,
168+
padding,
169+
dilation,
165170
):
166171
offsets = offsets.to(torch.int64)
167172
h_inc = offsets // kernel_width
@@ -182,8 +187,8 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
182187
hbase = bh * stride[0] - padding[0]
183188
wbase = bw * stride[1] - padding[1]
184189

185-
ih = hbase + h_inc
186-
iw = wbase + w_inc
190+
ih = hbase + h_inc * dilation[0]
191+
iw = wbase + w_inc * dilation[1]
187192
return ih * input_width + iw
188193

189194

@@ -195,7 +200,7 @@ def _low_memory_max_pool2d_offsets_to_indices_aten(
195200
)
196201

197202
_low_memory_max_pool2d_offsets_to_indices = make_prim(
198-
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding) -> Tensor", # noqa: B950
203+
"_low_memory_max_pool2d_offsets_to_indices(Tensor self, SymInt kernel_w, SymInt input_w, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", # noqa: B950
199204
_low_memory_max_pool2d_offsets_to_indices_aten,
200205
doc="Convert small int offsets to regular indices.",
201206
)

torch/_inductor/lowering.py

Lines changed: 36 additions & 25 deletions
</tr>
Original file line numberDiff line numberDiff line change
@@ -4319,14 +4319,22 @@ def load(index):
43194319
return load
43204320

43214321

4322-
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
4322+
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
4323+
if dilation is None:
4324+
dilation = [1] * len(padding)
4325+
43234326
x_out = FloorDiv(
4324-
x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i]
4327+
x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1),
4328+
stride[i],
43254329
)
43264330

43274331
if ceil_mode:
43284332
x_alt = FloorDiv(
4329-
x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i]
4333+
x
4334+
+ 2 * padding[i]
4335+
- dilation[i] * (kernel_size[i] - 1)
4336+
+ 2 * (stride[i] - 1),
4337+
stride[i],
43304338
)
43314339
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
43324340
# Sliding windows must start within the input or left padding
@@ -4341,10 +4349,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode):
43414349
return x_out, ceil_mode
43424350

43434351

4344-
def should_fallback_max_pool2d_with_indices(kernel_size, dilation):
4352+
def should_fallback_max_pool2d_with_indices(kernel_size):
43454353
kernel_size = pad_listlike(kernel_size, 2)
43464354
window_size = kernel_size[0] * kernel_size[1]
4347-
return (window_size > 25) or any(d > 1 for d in dilation)
4355+
return window_size > 25
43484356

43494357

43504358
def max_pool2d_checks(
@@ -4369,7 +4377,7 @@ def max_pool2d_checks(
43694377
assert len(dilation) == 2
43704378
assert len(x.get_size()) in (3, 4)
43714379

4372-
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation)
4380+
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size)
43734381
if assert_fallback is not None:
43744382
assert use_fallback == assert_fallback
43754383

@@ -4387,8 +4395,12 @@ def _max_pool2d_with_offsets(
43874395
x.realize_hint()
43884396
*batch, h, w = x.get_size()
43894397

4390-
h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode)
4391-
w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode)
4398+
h_out, ceil_mode1 = pooling_size(
4399+
h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation
4400+
)
4401+
w_out, ceil_mode2 = pooling_size(
4402+
w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation
4403+
)
43924404

43934405
dtype = x.dtype
43944406
min_value = (
@@ -4398,7 +4410,14 @@ def _max_pool2d_with_offsets(
43984410
)
43994411

44004412
new_size = list(batch) + [h_out, w_out]
4401-
if padding[0] or padding[1] or ceil_mode1 or ceil_mode2:
4413+
if (
4414+
padding[0]
4415+
or padding[1]
4416+
or ceil_mode1
4417+
or ceil_mode2
4418+
or (dilation[0] > 1)
4419+
or (dilation[1] > 1)
4420+
):
44024421
x_loader = constant_boundary_condition(x, min_value, dim=2)
44034422
else:
44044423
x_loader = x.make_loader()
@@ -4408,7 +4427,10 @@ def _max_pool2d_with_offsets(
44084427
def fn_inner(idx, reduction_idx):
44094428
prefix = idx[:-dim]
44104429
bh = idx[-dim:]
4411-
ih = [bh[i] * stride[i] + reduction_idx[i] - padding[i] for i in range(dim)]
4430+
ih = [
4431+
(bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
4432+
for i in range(dim)
4433+
]
44124434
return x_loader([*prefix, *ih])
44134435

44144436
result = Reduction.create(
@@ -4476,7 +4498,7 @@ def _low_memory_max_pool2d_with_offsets(
44764498
prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None
44774499
)
44784500
def _low_memory_max_pool2d_offsets_to_indices(
4479-
offsets, kernel_width, input_width, stride, padding
4501+
offsets, kernel_width, input_width, stride, padding, dilation
44804502
):
44814503
# TODO: Generalize to other max pooling flavors, and arbitrary dim
44824504

@@ -4486,8 +4508,8 @@ def increments_to_index(h_inc, w_inc, bh, bw):
44864508
w_in = ops.index_expr(input_width, torch.int64)
44874509
hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64)
44884510
wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64)
4489-
ih = hbase + h_inc
4490-
iw = wbase + w_inc
4511+
ih = hbase + h_inc * ops.constant(dilation[0], torch.int64)
4512+
iw = wbase + w_inc * ops.constant(dilation[1], torch.int64)
44914513
return ih * w_in + iw
44924514

44934515
def offsets_to_indices(idx):
@@ -4507,12 +4529,6 @@ def offsets_to_indices(idx):
45074529
return indices
45084530

45094531

4510-
fallback_max_pool2d_with_indices = fallback_handler(
4511-
aten.max_pool2d_with_indices.default,
4512-
add_to_fallback_set=False,
4513-
)
4514-
4515-
45164532
# Fallback when we do not decompose to the low-memory path.
45174533
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
45184534
def max_pool2d_with_indices(
@@ -4527,17 +4543,12 @@ def max_pool2d_with_indices(
45274543
x, kernel_size, stride, padding, dilation
45284544
)
45294545

4530-
if any(d > 1 for d in dilation):
4531-
return fallback_max_pool2d_with_indices(
4532-
x, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode
4533-
)
4534-
45354546
out, offsets = _max_pool2d_with_offsets(
45364547
x, kernel_size, stride, padding, dilation, ceil_mode
45374548
)
45384549

45394550
indices = _low_memory_max_pool2d_offsets_to_indices(
4540-
offsets, kernel_size[-1], x.shape[-1], stride, padding
4551+
offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation
45414552
)
45424553

45434554
return out, indices

0 commit comments

Comments
 (0)
0