[inductor] lowering for fractional_max_pool3d · pytorch/pytorch@591b17d

Commit 591b17d

[inductor] lowering for fractional_max_pool3d
also a lowering with a reduction for large window_sizes for fractional_max_pool2d

ghstack-source-id: b228249
Pull Request resolved: #148630
1 parent 48bfe9a commit 591b17d
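
For context, a minimal usage sketch (not part of the commit; shapes and values are illustrative): with this change, torch.compile can serve aten.fractional_max_pool3d from a generated inductor kernel instead of the eager fallback.

import torch

def fn(x, samples):
    # aten signature: (self, kernel_size, output_size, random_samples)
    return torch.ops.aten.fractional_max_pool3d(x, (2, 2, 2), (4, 4, 4), samples)

compiled = torch.compile(fn)
x = torch.randn(2, 4, 16, 16, 16)
samples = torch.rand(2, 4, 3)  # one uniform sample per spatial dim, per (N, C)
out, indices = compiled(x, samples)  # routed through the new lowering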

4 files changed (+105, -93 lines)

test/inductor/test_torchinductor.py (+1, -3)
@@ -4859,18 +4859,16 @@ def fn(x, samples):
 
     @xfail_if_mps_unimplemented
     def test_fractional_max_pool2d2(self):
-        # fallback for larger kernel size
+        # large kernel size without unrolling
 
         def fn(x, samples):
            8000  return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)
 
-        torch._inductor.metrics.generated_kernel_count = 0
         self.common(
             fn,
             (torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
             check_lowp=False,
         )
-        assertGeneratedKernelCountEqual(self, 0)
 
     @xfail_if_mps_unimplemented
     def test_fractional_max_pool2d3(self):
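
The dropped assertion tracked inductor's kernel counter: with the old fallback for large kernels, no kernel was generated, so the test asserted a count of zero. A rough sketch of that mechanism (hypothetical standalone usage, not part of the commit):

import torch
import torch._inductor.metrics as metrics

metrics.generated_kernel_count = 0
compiled = torch.compile(
    lambda x, s: torch.ops.aten.fractional_max_pool2d(x, (6, 5), (3, 3), s)
)
compiled(torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2))
# Under the old fallback this stayed 0; with the reduction lowering it is > 0,
# so the zero-count assertion no longer holds and is removed.
print(metrics.generated_kernel_count)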

test/inductor/test_torchinductor_codegen_dynamic_shapes.py (-1)
@@ -144,7 +144,6 @@ def run(*ex, **kwargs):
     "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
-    "test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),

torch/_inductor/codegen/mps.py (+2)
@@ -144,6 +144,8 @@ def _print_FloorToInt(self, expr: sympy.Expr) -> str:
         x = self.doprint(expr.args[0])
         return f"static_cast<int>(metal::floor({x}))"
 
+    _print_floor = _print_FloorToInt
+
     def _print_TruncToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         x = self.doprint(expr.args[0])
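
The alias makes the MPS expression printer handle plain sympy.floor the same way as FloorToInt, since the new pooling index math can emit floor expressions. A minimal sketch of the dispatch pattern (hypothetical printer class, not the commit's code): sympy printers dispatch on the expression's class name, so binding one method to two _print_* names reuses the handler.

import sympy
from sympy.printing.str import StrPrinter

class MetalPrinterSketch(StrPrinter):
    def _print_FloorToInt(self, expr):
        x = self.doprint(expr.args[0])
        return f"static_cast<int>(metal::floor({x}))"

    # sympy.floor(...) dispatches to _print_floor; reuse the same handler
    _print_floor = _print_FloorToInt

print(MetalPrinterSketch().doprint(sympy.floor(sympy.Symbol("x") / 3)))
# static_cast<int>(metal::floor(x/3))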

torch/_inductor/lowering.py (+102, -89)
@@ -2600,7 +2600,6 @@ def is_aligned(x):
 # WIP
 make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
 make_fallback(aten.adaptive_max_pool3d)  # @isuruf
-make_fallback(aten.fractional_max_pool3d)  # @isuruf
 make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet
 
 
@@ -4472,34 +4471,27 @@ def _low_memory_max_pool_with_offsets(
     return result, to_dtype(offsets, torch.int8)
 
 
-@register_lowering(
-    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
-)
-def _low_memory_max_pool_offsets_to_indices(
-    offsets, kernel_size, input_size, stride, padding, dilation
-):
-    # TODO: Generalize to other max pooling flavors
+def _pool_offsets_to_indices(
+    offsets: TensorBox,
+    kernel_size: Sequence[Union[int, torch.SymInt]],
+    input_size: Sequence[Union[int, torch.SymInt]],
+    increments_to_index: Callable[
+        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
+        torch._inductor.virtualized.OpsValue,
+    ],
+) -> TensorBox:
     n_dim = len(kernel_size)
     offsets_loader = offsets.make_loader()
-
-    def increments_to_index(dhw_inc, bh):
-        w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
-        hbase = [
-            ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        idhw = [
-            hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
-            for d in range(n_dim)
-        ]
-        return inductor_prims._flatten_index(idhw, w_in)
+    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))
 
     def offsets_to_indices(idx):
-        bh = idx[-n_dim:]
         offset = offsets_loader(idx)
-        k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
-        dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
-        return increments_to_index(dhw_inc, bh)
+        offset_sympy = ops.indirect_indexing(offset, window_size)
+        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
+        idhw = increments_to_index(idx, reduction_idx)
+        return ops.index_expr(
+            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
+        )
 
     indices = Pointwise.create(
         device=offsets.get_device(),
@@ -4510,6 +4502,27 @@ def offsets_to_indices(idx):
     return indices
 
 
+@register_lowering(
+    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
+)
+def _low_memory_max_pool_offsets_to_indices(
+    offsets, kernel_size, input_size, stride, padding, dilation
+):
+    # TODO: Generalize to other max pooling flavors
+    n_dim = len(kernel_size)
+
+    def increments_to_index(idx, reduction_idx):
+        bh = idx[-n_dim:]
+        return [
+            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
+            for i in range(n_dim)
+        ]
+
+    return _pool_offsets_to_indices(
+        offsets, kernel_size, input_size, increments_to_index
+    )
+
+
 def _max_pool_with_indices(
     x,
     kernel_size,
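
To see what the refactored helper computes, here is the same index arithmetic as plain Python (a hypothetical illustration, not inductor code): the argmax reduction stores a flattened offset within the kernel window; _pool_offsets_to_indices unflattens it, maps it to an input coordinate via the caller-supplied increments_to_index, and reflattens over the input's spatial dims.

def max_pool_offset_to_index(offset, out_pos, kernel_size, input_size,
                             stride, padding, dilation):
    kh, kw = kernel_size
    inc = (offset // kw, offset % kw)  # unflatten window offset to 2-d increments
    idx = [out_pos[d] * stride[d] - padding[d] + inc[d] * dilation[d]
           for d in range(2)]          # the max-pool increments_to_index
    return idx[0] * input_size[1] + idx[1]  # flatten over input spatial dims

# window offset 4 in a 3x3 kernel is increment (1, 1):
print(max_pool_offset_to_index(4, (2, 3), (3, 3), (8, 8), (2, 2), (0, 0), (1, 1)))
# (2*2 + 1) * 8 + (3*2 + 1) = 47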
@@ -5013,11 +5026,6 @@ def inner_fn_max_idx(idx):
     return rv, ri
 
 
-fallback_fractional_max_pool2d = fallback_handler(
-    aten.fractional_max_pool2d.default, add_to_fallback_set=False
-)
-
-
 def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
     out_sz = out_sz[dim]
     in_sz = in_sz[dim]
@@ -5036,80 +5044,85 @@ def load(prefix, i):
         seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
         seq_i = ops.to_dtype(seq_i, torch.int64)
         mask = ops.lt(i_expr, out_sz_expr)
-        return ops.where(mask, seq_i, diff)
+        return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))
 
     return load
 
 
 @register_lowering(aten.fractional_max_pool2d)
 def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
-    x.realize_hint()
-    *batch, inp_h, inp_w = x.get_size()
-    kernel_h, kernel_w = kernel_size
-    h_out, w_out = output_size
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)
 
-    if kernel_h * kernel_w >= 25:
-        return fallback_fractional_max_pool2d(
-            x, kernel_size, output_size, random_samples
-        )
 
-    gen_offsets_for_dim = functools.partial(
-        _fractional_pooling_offsets,
-        samples=random_samples,
-        in_sz=[inp_h, inp_w],
-        out_sz=output_size,
-        kernel_sz=kernel_size,
-        ndims=2,
-    )
+@register_lowering(aten.fractional_max_pool3d)
+def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
+    return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)
 
-    h_index_fn = gen_offsets_for_dim(dim=0)
-    w_index_fn = gen_offsets_for_dim(dim=1)
-    x_loader = x.make_loader()
 
-    def fn(idx, return_index):
-        *prefix, bh, bw = idx
+def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
+    x.realize_hint()
+    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]
 
-        h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
-        w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
+    with config.patch(unroll_reductions_threshold=25):
+        dhw_index_fn = [
+            _fractional_pooling_offsets(
+                samples=random_samples,
+                in_sz=inp_dhw,
+                out_sz=output_size,
+                kernel_sz=kernel_size,
+                ndims=n_dim,
+                dim=d,
+            )
+            for d in range(n_dim)
+        ]
 
-        maxval = None
-        maxindex = None
-        for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
-            val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
-            if return_index:
-                index = ops.index_expr(
-                    (h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
-                )
-                if maxindex is None:
-                    maxindex = index
-                else:
-                    maxindex = ops.where(
-                        ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
-                    )
-            if maxval is None:
-                maxval = val
-            else:
-                maxval = ops.maximum(val, maxval)
-        if return_index:
-            return maxindex
-        else:
-            return maxval
+        x_loader = x.make_loader()
 
-    new_size = list(batch) + [h_out, w_out]
-    rv = Pointwise.create(
-        device=x.get_device(),
-        dtype=x.get_dtype(),
-        inner_fn=functools.partial(fn, return_index=False),
-        ranges=new_size,
-    )
+        def fn_inner(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])
 
-    ri = Pointwise.create(
-        device=x.get_device(),
-        dtype=torch.int64,
-        inner_fn=functools.partial(fn, return_index=True),
-        ranges=new_size,
-    )
-    return rv, ri
+        def increments_to_index(idx, reduction_idx):
+            prefix = idx[:-n_dim]
+            bdhw = idx[-n_dim:]
+            return [
+                dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
+                for d in range(n_dim)
+            ]
+
+        new_size = list(batch) + list(output_size)
+        dtype = x.get_dtype()
+        result = Reduction.create(
+            reduction_type="max",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=dtype,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        offsets = Reduction.create(
+            reduction_type="argmax",
+            input_node=x,
+            device=x.get_device(),
+            dst_dtype=torch.int64,
+            src_dtype=dtype,
+            inner_fn=fn_inner,
+            ranges=new_size,
+            reduction_ranges=kernel_size,
+        )
+        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            result.realize()
+        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
+            # Only realize if reduction isn't unrolled
+            offsets.realize()
+
+    indices = _pool_offsets_to_indices(
+        offsets, kernel_size, x.shape, increments_to_index
+    )
+    return result, indices
 
 
 @register_lowering(aten.upsample_nearest2d_backward.default)
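
For reference, the per-dimension start-offset sequence that _fractional_pooling_offsets builds, written as plain Python (an illustrative sketch of the formula visible in the context lines above): alpha = (in - k) / (out - 1), each window starts at trunc((i + u) * alpha) - trunc(u * alpha), and the last window is pinned to in - k so it never runs out of bounds.

import math

def fractional_starts(in_sz, out_sz, kernel_sz, u):
    # u is the per-(batch, channel, dim) random sample in [0, 1)
    alpha = (in_sz - kernel_sz) / (out_sz - 1)
    return [
        math.trunc((i + u) * alpha) - math.trunc(u * alpha)
        if i < out_sz - 1
        else in_sz - kernel_sz
        for i in range(out_sz)
    ]

# pooling 36 -> 3 with kernel 6, as in the test above:
print(fractional_starts(36, 3, 6, 0.42))  # [0, 15, 30] for this sample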
