[inductor] lowering for fractional_max_pool3d · pytorch/pytorch@df02808 · GitHub
Commit df02808

[inductor] lowering for fractional_max_pool3d
Also adds a lowering with a reduction for large window sizes for fractional_max_pool2d.

ghstack-source-id: f095b47
Pull Request resolved: #148630
1 parent 3c8541d commit df02808
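For context, a minimal sketch of how the newly lowered op can be exercised end-to-end through torch.compile with the inductor backend; the shapes and sample tensors below are illustrative, not taken from this commit.

# Illustrative only: exercise aten.fractional_max_pool3d through inductor.
import torch

def f(x, samples):
    # returns (pooled values, flat indices into the pooled dims)
    return torch.ops.aten.fractional_max_pool3d(x, (2, 2, 2), (4, 4, 4), samples)

x = torch.randn(2, 4, 16, 16, 16)   # (N, C, T, H, W)
samples = torch.rand(2, 4, 3)       # one sample in [0, 1) per pooled dim
values, indices = torch.compile(f)(x, samples)
print(values.shape, indices.shape)  # both (2, 4, 4, 4, 4)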

File tree

3 files changed, +97 -93 lines

test/inductor/test_torchinductor.py

Lines changed: 1 addition & 3 deletions

@@ -4726,18 +4726,16 @@ def fn(x, samples):
         )
 
     def test_fractional_max_pool2d2(self):
-        # fallback for larger kernel size
+        # large kernel size without unrolling
 
         def fn(x, samples):
            return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)
 
-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
             check_lowp=False,
         )
-        assertGeneratedKernelCountEqual(self, 0)
 
     def test_fractional_max_pool2d3(self):
         def fn(x, samples):
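The (6, 5) window has 30 elements, above the 25-element unroll threshold, so this case is now lowered as a real reduction kernel instead of falling back to eager; pinning the generated-kernel count to zero therefore no longer applies. A hedged sketch of the same scenario outside the test harness follows; the metrics check is an assumption about how one might observe the change, not part of the test.

# Illustrative harness mirroring test_fractional_max_pool2d2: with a (6, 5)
# window the op is now compiled rather than falling back to eager.
import torch
import torch._inductor.metrics as metrics

def fn(x, samples):
    return torch.ops.aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)

x = torch.randn(2, 4, 36, 36)
samples = torch.rand(2, 4, 2)

metrics.reset()
vals, idx = torch.compile(fn)(x, samples)
eager_vals, eager_idx = fn(x, samples)
torch.testing.assert_close(vals, eager_vals)
# Before this commit the >= 25-element window fell back to eager; now at
# least one inductor kernel is expected to be generated.
print(metrics.generated_kernel_count)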

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion

@@ -140,7 +140,6 @@ def run(*ex, **kwargs):
     "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

torch/_inductor/lowering.py

Lines changed: 96 additions & 89 deletions

@@ -2595,7 +2595,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 
 
 # 1) Easy
@@ -4472,34 +4471,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)
 
 
-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.index_expr(dilation[d], torch.int64)
-            for d in range(dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))
 
     def offsets_to_indices(idx):
-        bh = idx[-dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-dim:]), torch.int64
+        )
 
     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4510,6 +4494,27 @@ def offsets_to_indices(idx):
     return indices
 
 
+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-dim:]
+        return [
+            bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
+            for i in range(dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
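The refactor above moves the offset-to-index arithmetic into _pool_offsets_to_indices, with the pooling-specific mapping supplied as increments_to_index. A plain-value sketch of that arithmetic for the regular max-pool case, operating on Python ints instead of inductor IR ops (the function name below is illustrative):

# Scalar sketch of what _pool_offsets_to_indices computes for max pooling:
# a flat offset inside the kernel window is unflattened into per-dim
# increments, mapped to input coordinates, and reflattened over the
# pooled input dims.
def offset_to_flat_index(offset, out_pos, kernel_size, input_size, stride, padding, dilation):
    dim = len(kernel_size)
    # inverse of row-major flattening over the kernel window
    increments, rem = [], offset
    for k in reversed(kernel_size):
        increments.append(rem % k)
        rem //= k
    increments.reverse()
    # output position + window increment -> input coordinate per dim
    coords = [
        out_pos[d] * stride[d] + increments[d] * dilation[d] - padding[d]
        for d in range(dim)
    ]
    # row-major flatten over the pooled input dims
    flat = 0
    for d in range(dim):
        flat = flat * input_size[d] + coords[d]
    return flat

# 3x3 window, stride 2, no padding/dilation: offset 5 is increment (1, 2),
# so output position (1, 1) of a 7x7 input maps to (3, 4) -> 3 * 7 + 4 = 25.
assert offset_to_flat_index(5, (1, 1), (3, 3), (7, 7), (2, 2), (0, 0), (1, 1)) == 25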
@@ -5013,11 +5018,6 @@ def inner_fn_max_idx(idx):
     return rv, ri
 
 
-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
@@ -5038,80 +5038,87 @@ def load(prefix, i):
             i_expr,
             ops.index_expr(out_sz - 1, torch.int64),
         )
-        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
+        return ops.indirect_indexing(
+            ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)),
+            sympy.sympify(in_sz),
+        )
 
     return load
 
 
 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=2)
 
-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )
 
-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=3)
 
-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()
 
-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-dim], x.shape[-dim:]
 
-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=dim,
+                dim=d,
+            )
+            for d in range(dim)
+        ]
 
-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()
 
-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])
 
-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-dim]
+            bdhw = idx[-dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d] for d in range(dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+        indices = _pool_offsets_to_indices(
+            offsets, kernel_size, x.shape, increments_to_index
+        )
+        return result, indices
 
 
 @register_lowering(aten.upsample_nearest2d_backward.default)
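The new _fractional_max_pool expresses the op as two reductions over the kernel window, a "max" for the values and an "argmax" for the within-window offsets, with per-dimension window starts derived from the random samples; windows smaller than 25 elements stay unrolled via the patched unroll_reductions_threshold, and _pool_offsets_to_indices turns the argmax offsets back into flat input indices. Below is a hedged eager reference of the same decomposition for the 2D case; the window-start formula mirrors ATen's generate_intervals and the reversed per-dimension sample ordering is an assumption, not something stated in this diff.

# Reference sketch (plain PyTorch, 2D only) of the decomposition used by the
# lowering: fractional window starts, then max/argmax per window, then the
# within-window argmax offset mapped back to a flat H*W input index.
import torch

def window_starts(sample, in_sz, out_sz, k):
    # assumed formula, following ATen's generate_intervals
    alpha = (in_sz - k) / max(out_sz - 1, 1)  # guard out_sz == 1
    i = torch.arange(out_sz, dtype=torch.float64)
    starts = ((i + sample) * alpha).long() - int(sample * alpha)
    starts[-1] = in_sz - k  # last window is pinned to the end of the input
    return starts.tolist()

def fractional_max_pool2d_reference(x, kernel_size, output_size, samples):
    n, c, in_h, in_w = x.shape
    kh, kw = kernel_size
    out_h, out_w = output_size
    vals = x.new_empty(n, c, out_h, out_w)
    idx = torch.empty(n, c, out_h, out_w, dtype=torch.int64)
    for b in range(n):
        for ch in range(c):
            # per-dimension samples; ATen appears to store them with the last
            # spatial dim first, which is treated as an assumption here
            hs = window_starts(samples[b, ch, 1].item(), in_h, out_h, kh)
            ws = window_starts(samples[b, ch, 0].item(), in_w, out_w, kw)
            for i in range(out_h):
                for j in range(out_w):
                    win = x[b, ch, hs[i] : hs[i] + kh, ws[j] : ws[j] + kw]
                    flat = win.reshape(-1)
                    off = int(flat.argmax())        # the "argmax" reduction
                    vals[b, ch, i, j] = flat[off]   # the "max" reduction
                    di, dj = divmod(off, kw)        # offset -> window increments
                    idx[b, ch, i, j] = (hs[i] + di) * in_w + (ws[j] + dj)
    return vals, idx

Expressing both the value and the index as reductions over the same window is what lets inductor keep small windows unrolled while lowering large ones, such as the (6, 5) case in the updated test, as real reduction kernels.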

0 commit comments