[inductor] lowering for fractional_max_pool3d · pytorch/pytorch@fc3d111 · GitHub

Commit fc3d111
[inductor] lowering for fractional_max_pool3d
Also adds a lowering with a reduction over the pooling window for large window sizes in fractional_max_pool2d.

ghstack-source-id: 4c4fee9
Pull Request resolved: #148630
1 parent 9ebd25d commit fc3d111
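For context, fractional max pooling picks a pseudo-random window start per output element (driven by the random_samples tensor), so each output is a max over a small shifted window. Below is a minimal sketch of exercising both ops through torch.compile; the 3d shapes are illustrative assumptions, while the 2d shapes mirror the test changed in this commit.

    # Sketch: run the newly lowered ops under torch.compile (3d shapes are
    # assumptions; the 2d case mirrors test_fractional_max_pool2d2).
    import torch
    import torch.nn.functional as F

    def pool2d(x):
        # a (6, 5) window is large enough to take the new reduction-based path
        return F.fractional_max_pool2d(
            x, kernel_size=(6, 5), output_size=(3, 3), return_indices=True
        )

    def pool3d(x):
        # previously a fallback to the eager kernel; now lowered by inductor
        return F.fractional_max_pool3d(
            x, kernel_size=2, output_size=(4, 4, 4), return_indices=True
        )

    out2d, idx2d = torch.compile(pool2d)(torch.randn(2, 4, 36, 36))
    out3d, idx3d = torch.compile(pool3d)(torch.randn(2, 4, 16, 16, 16))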

File tree

3 files changed: +98 -93 lines changed


test/inductor/test_torchinductor.py

Lines changed: 1 addition & 3 deletions

@@ -4751,18 +4751,16 @@ def fn(x, samples):
         )
 
     def test_fractional_max_pool2d2(self):
-        # fallback for larger kernel size
+        # large kernel size without unrolling
 
         def fn(x, samples):
             return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)
 
-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
             check_lowp=False,
         )
-        assertGeneratedKernelCountEqual(self, 0)
 
     def test_fractional_max_pool2d3(self):
         def fn(x, samples):

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 0 additions & 1 deletion

@@ -142,7 +142,6 @@ def run(*ex, **kwargs):
     "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
torch/_inductor/lowering.py

Lines changed: 97 additions & 89 deletions

@@ -2617,7 +2617,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 
 
 # 1) Easy
@@ -4494,34 +4493,19 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)
 
 
-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(offsets, kernel_size, input_size, increments_to_index):
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))
 
     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )
 
     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4532,6 +4516,27 @@ def offsets_to_indices(idx):
     return indices
 
 
+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
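Taken together, these two hunks convert an argmax offset within the flattened kernel window back into a flat index into the input. A plain-Python sketch of that arithmetic follows; it assumes _flattened_index_to_nd and _flatten_index are ordinary row-major unravel/ravel helpers, which is an assumption about inductor_prims rather than something shown in this diff.

    # Sketch of the offsets-to-indices arithmetic for max pooling (plain Python,
    # scalar integers instead of inductor ops / sympy expressions).
    def unravel(flat, sizes):
        # row-major: last dimension varies fastest
        out = []
        for s in reversed(sizes):
            out.append(flat % s)
            flat //= s
        return list(reversed(out))

    def ravel(coords, sizes):
        flat = 0
        for c, s in zip(coords, sizes):
            flat = flat * s + c
        return flat

    def offset_to_input_index(offset, out_idx, kernel_size, input_size,
                              stride, padding, dilation):
        # offset: position of the max within the flattened kernel window
        k_idx = unravel(offset, kernel_size)          # reduction_idx
        in_idx = [
            out_idx[d] * stride[d] + k_idx[d] * dilation[d] - padding[d]
            for d in range(len(kernel_size))
        ]                                             # increments_to_index
        return ravel(in_idx, input_size)              # _flatten_index

    # e.g. a 3x3 window, stride 2, no padding/dilation, 8x8 input:
    # offset 4 (window centre) at output position (1, 2) -> input (3, 5)
    assert offset_to_input_index(4, [1, 2], [3, 3], [8, 8],
                                 [2, 2], [0, 0], [1, 1]) == 3 * 8 + 5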
@@ -5035,11 +5040,6 @@ def inner_fn_max_idx(idx):
     return rv, ri
 
 
-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
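The next hunk touches the tail of _fractional_pooling_offsets, which produces the window start for each output index from the per-sample random value. For reference, the interval formula used by PyTorch's eager fractional max pooling kernels is sketched below in plain Python; that the lowering mirrors this exact formula is my assumption, only the final in_sz - kernel_sz clamp for the last output index is visible in the diff itself.

    def fractional_starts(sample, in_sz, out_sz, kernel_sz):
        # sample is the per-(batch, channel, dim) value from random_samples in [0, 1)
        alpha = (in_sz - kernel_sz) / (out_sz - 1)
        starts = [
            int((i + sample) * alpha) - int(sample * alpha)
            for i in range(out_sz - 1)
        ]
        # the last window is pinned to the end of the input
        starts.append(in_sz - kernel_sz)
        return starts

    # e.g. in_sz=36, kernel_sz=6, out_sz=3 -> three window starts in [0, 30]
    print(fractional_starts(0.37, 36, 3, 6))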
@@ -5060,80 +5060,88 @@ def load(prefix, i):
             i_expr,
             ops.index_expr(out_sz - 1, torch.int64),
         )
-        return ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64))
+        return ops.indirect_indexing(
+            ops.where(mask, seq_i, ops.index_expr(in_sz - kernel_sz, torch.int64)),
+            sympy.sympify(in_sz),
+        )
 
     return load
 
 
 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)
 
-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )
 
-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)
 
-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()
 
-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]
 
-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]
 
-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()
 
-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])
 
-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices
 
 
 @register_lowering(aten.upsample_nearest2d_backward.default)
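Since fractional max pooling is randomized only through the explicit random_samples argument, the lowering can be checked against eager directly, which is essentially what the test suite's self.common does. A minimal sketch of such a check, using the same 2d shapes as the modified test (the equivalent 3d call would pass an (N, C, 3) samples tensor, which is an assumption here, not something shown in this diff):

    import torch

    aten = torch.ops.aten

    def fn(x, samples):
        return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)

    x = torch.randn(2, 4, 36, 36)
    samples = torch.rand(2, 4, 2)

    eager_out, eager_idx = fn(x, samples)
    compiled_out, compiled_idx = torch.compile(fn)(x, samples)

    torch.testing.assert_close(eager_out, compiled_out)
    torch.testing.assert_close(eager_idx, compiled_idx)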
