[inductor] lowering for fractional_max_pool3d · pytorch/pytorch@6cd80b2 · GitHub
Commit 6cd80b2
[inductor] lowering for fractional_max_pool3d
Also adds a lowering with a reduction for large window sizes for fractional_max_pool2d.

ghstack-source-id: cb31542
Pull Request resolved: #148630
1 parent 5142081 commit 6cd80b2
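
What this enables, as a minimal sketch (shapes, names, and tolerances are illustrative, not from this PR): a compiled call to fractional_max_pool3d should now be lowered by Inductor rather than falling back to the eager ATen kernel.

    import torch
    import torch.nn.functional as F

    def pool3d(x, samples):
        # fractional max pooling with a 2x2x2 window; returns (values, indices)
        return F.fractional_max_pool3d(
            x,
            kernel_size=(2, 2, 2),
            output_size=(4, 4, 4),
            return_indices=True,
            _random_samples=samples,
        )

    x = torch.randn(1, 3, 16, 16, 16)
    samples = torch.rand(1, 3, 3)  # one random offset per spatial dim
    torch.testing.assert_close(torch.compile(pool3d)(x, samples), pool3d(x, samples))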

File tree

2 files changed: +97 -92 lines changed

- test/inductor/test_torchinductor.py
- torch/_inductor/lowering.py

test/inductor/test_torchinductor.py

Lines changed: 1 addition & 3 deletions

@@ -4736,18 +4736,16 @@ def fn(x, samples):
         )
 
     def test_fractional_max_pool2d2(self):
-        # fallback for larger kernel size
+        # large kernel size without unrolling
 
         def fn(x, samples):
            return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)
 
-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
             check_lowp=False,
         )
-        assertGeneratedKernelCountEqual(self, 0)
 
     def test_fractional_max_pool2d3(self):
         def fn(x, samples):
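
For context on the test change above, a hedged sketch of what it now exercises (shapes match the test): a 6x5 window has 30 elements, which exceeds the 25-element unroll threshold the new lowering patches in, so the max/argmax is emitted as a genuine reduction instead of hitting the old eager fallback.

    import torch

    aten = torch.ops.aten

    def fn(x, samples):
        # 6*5 = 30 window elements: too large to unroll, lowered as a reduction
        return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)

    x = torch.randn(2, 4, 36, 36)
    samples = torch.rand(2, 4, 2)
    torch.testing.assert_close(torch.compile(fn)(x, samples), fn(x, samples))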

torch/_inductor/lowering.py

Lines changed: 96 additions & 89 deletions

@@ -2601,7 +2601,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 
 
 # 1) Easy
@@ -4464,34 +4463,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)
 
 
-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.index_expr(dilation[d], torch.int64)
-            for d in range(dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))
 
     def offsets_to_indices(idx):
-        bh = idx[-dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-dim:]), torch.int64
+        )
 
     indices = Pointwise.create(
         device=offsets.get_device(),
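
To make the index arithmetic in _pool_offsets_to_indices concrete, here is a plain-Python analogue (not Inductor IR; the helpers below are illustrative stand-ins for inductor_prims._flattened_index_to_nd and _flatten_index, assuming row-major flattening):

    # An argmax over a flattened window is decomposed into per-dimension
    # increments, then recombined into a flat index over the input plane.

    def flattened_index_to_nd(offset, sizes):
        nd = []
        for s in reversed(sizes):
            nd.append(offset % s)
            offset //= s
        return list(reversed(nd))

    def flatten_index(nd, sizes):
        flat = 0
        for i, s in zip(nd, sizes):
            flat = flat * s + i
        return flat

    kernel_size = (2, 3)      # pooling window
    input_size = (8, 9)       # spatial size of the input
    window_start = (4, 6)     # where this output element's window begins
    offset = 4                # argmax position within the flattened 2*3 window

    inc = flattened_index_to_nd(offset, kernel_size)     # [1, 1]
    idhw = [window_start[d] + inc[d] for d in range(2)]  # [5, 7]
    print(flatten_index(idhw, input_size))               # 5*9 + 7 = 52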
@@ -4502,6 +4486,27 @@ def offsets_to_indices(idx):
     return indices
 
 
+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-dim:]
+        return [
+            bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
+            for i in range(dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
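
The increments_to_index above restores the usual max-pool arithmetic: for output position bh and in-window position reduction_idx, the input coordinate along dimension i is bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]. A quick numeric check with invented values:

    stride, dilation, padding = (2, 2), (1, 1), (1, 1)
    bh, reduction_idx = (3, 0), (2, 1)
    print([bh[i] * stride[i] + reduction_idx[i] * dilation[i] - padding[i]
           for i in range(2)])  # [7, 0]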
@@ -5005,11 +5010,6 @@ def inner_fn_max_idx(idx):
     return rv, ri
 
 
-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
@@ -5030,80 +5030,87 @@ def load(prefix, i):
             i_expr,
             ops.index_expr(out_sz - 1, torch.int64),
         )
-        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
+        return ops.indirect_indexing(
+            ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)),
+            sympy.sympify(in_sz),
+        )
 
     return load
 
 
 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=2)
 
-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )
 
-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, dim=3)
 
-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()
 
-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-dim], x.shape[-dim:]
 
-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=dim,
+                dim=d,
+            )
+            for d in range(dim)
+        ]
 
-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()
 
-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])
 
-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-dim]
+            bdhw = idx[-dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d] for d in range(dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices
 
 
 @register_lowering(aten.upsample_nearest2d_backward.default)
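
As a sanity check on the (values, indices) pair _fractional_max_pool returns, a small eager-mode sketch (shapes arbitrary; it assumes, as _pool_offsets_to_indices computes above, that indices are flat offsets into each D*H*W volume):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 3, 8, 8, 8)
    samples = torch.rand(2, 3, 3)
    vals, idx = F.fractional_max_pool3d(
        x, kernel_size=2, output_size=4, return_indices=True, _random_samples=samples
    )
    # gathering with the flat indices should reproduce the pooled values
    flat = x.flatten(2)                                    # (2, 3, 512)
    regathered = flat.gather(2, idx.flatten(2)).view_as(vals)
    torch.testing.assert_close(regathered, vals)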
