[inductor] lowering for fractional_max_pool3d · pytorch/pytorch@abb5eeb · GitHub
Commit abb5eeb

[inductor] lowering for fractional_max_pool3d
Also adds a lowering that uses a reduction for large window sizes in fractional_max_pool2d.

ghstack-source-id: ee439f9
Pull Request resolved: #148630
1 parent 0b6ea0b commit abb5eeb
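
With the fallback removed, aten.fractional_max_pool3d now goes through the new Inductor lowering whenever it is compiled. A minimal way to exercise it end to end (the shapes and the explicit _random_samples tensor below are illustrative, not taken from this commit):

import torch
import torch.nn.functional as F

def pool(x, samples):
    # return_indices=True exercises both the max reduction and the
    # argmax/offsets-to-indices path added here
    return F.fractional_max_pool3d(
        x, (2, 2, 2), output_size=(4, 4, 4),
        return_indices=True, _random_samples=samples,
    )

compiled = torch.compile(pool)
x = torch.randn(1, 3, 8, 8, 8)   # (N, C, T, H, W)
samples = torch.rand(1, 3, 3)    # one (t, h, w) sample per (N, C)
out, idx = compiled(x, samples)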

File tree

3 files changed: +95 -93 lines

test/inductor/test_torchinductor.py

+1 -3

@@ -4839,18 +4839,16 @@ def fn(x, samples):

     @xfail_if_mps_unimplemented
     def test_fractional_max_pool2d2(self):
-        # fallback for larger kernel size
+        # large kernel size without unrolling

         def fn(x, samples):
             return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)

-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
             check_lowp=False,
         )
-        assertGeneratedKernelCountEqual(self, 0)

     @xfail_if_mps_unimplemented
     def test_fractional_max_pool2d3(self):
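
For context on the updated comment: a (6, 5) window covers 30 elements, above the unroll_reductions_threshold=25 that the new _fractional_max_pool lowering patches in, so this case now compiles to a genuine (non-unrolled) max/argmax reduction instead of falling back to ATen, which is presumably why the generated-kernel-count assertions were dropped. A sketch of running the same shapes through the compiled path (illustrative only, mirroring the test inputs):

import torch
import torch.nn.functional as F

def fn(x, samples):
    return F.fractional_max_pool2d(
        x, (6, 5), output_size=(3, 3), return_indices=True, _random_samples=samples
    )

x = torch.randn(2, 4, 36, 36)
samples = torch.rand(2, 4, 2)
vals, idx = torch.compile(fn)(x, samples)  # 6 * 5 = 30 > 25: window is not unrolled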

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

-1

@@ -144,7 +144,6 @@ def run(*ex, **kwargs):
     "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

torch/_inductor/lowering.py

+94 -89

@@ -2602,7 +2602,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet


@@ -4475,34 +4474,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)


-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )

     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4513,6 +4497,27 @@ def offsets_to_indices(idx):
     return indices


+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
@@ -5016,11 +5021,6 @@ def inner_fn_max_idx(idx):
     return rv, ri


-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
@@ -5039,80 +5039,85 @@ def load(prefix, i):
         seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
         seq_i = ops.to_dtype(seq_i, torch.int64)
         mask = ops.lt(i_expr, out_sz_expr)
-        return ops.where(mask, seq_i, diff)
+        return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))

     return load


 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)

-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )

-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)

-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()

-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]

-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()

-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices


 @register_lowering(aten.upsample_nearest2d_backward.default)
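
For readers new to the offsets-to-indices step: the argmax reduction produces a flat offset inside the pooling window, and _pool_offsets_to_indices turns that into a flat index over the input's spatial dims. The sketch below mirrors that arithmetic on plain Python integers under assumed row-major semantics; the helper names and example sizes are illustrative, not Inductor APIs.

# Plain-integer sketch (assumed semantics, not Inductor IR): unflatten the
# window-local argmax offset into per-dimension increments, add the window's
# start position, then re-flatten over the input's spatial size.

def unflatten(offset, sizes):
    # flat offset within the window -> per-dim increments (row-major)
    out = []
    for s in reversed(sizes):
        out.append(offset % s)
        offset //= s
    return list(reversed(out))

def flatten(coords, sizes):
    # per-dim coordinates -> flat index over the spatial dims (row-major)
    idx = 0
    for c, s in zip(coords, sizes):
        idx = idx * s + c
    return idx

kernel_size = (2, 2, 2)       # pooling window (T, H, W)
input_size = (8, 8, 8)        # spatial size of the input
window_start = (4, 2, 6)      # start of this output element's window
offset = 5                    # flat argmax offset within the window

increments = unflatten(offset, kernel_size)             # -> [1, 0, 1]
coords = [s + i for s, i in zip(window_start, increments)]
flat_index = flatten(coords, input_size)                # value stored in the returned indices tensor
print(increments, coords, flat_index)                   # [1, 0, 1] [5, 2, 7] 343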
