8000 [inductor] lowering for fractional_max_pool3d by isuruf · Pull Request #148630 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] lowering for fractional_max_pool3d #148630

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions test/inductor/test_torchinductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4859,18 +4859,16 @@ def fn(x, samples):

@xfail_if_mps_unimplemented
def test_fractional_max_pool2d2(self):
# fallback for larger kernel size
# large kernel size without unrolling

def fn(x, samples):
return aten.fractional_max_pool2d(x, (6, 5), (3, 3), samples)

torch._inductor.metrics.generated_kernel_count = 0
self.common(
fn,
(torch.randn(2, 4, 36, 36), torch.rand(2, 4, 2)),
check_lowp=False,
)
assertGeneratedKernelCountEqual(self, 0)

@xfail_if_mps_unimplemented
def test_fractional_max_pool2d3(self):
Expand Down
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def run(*ex, **kwargs):
"test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_fractional_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
"test_avg_pool2d_backward4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
Expand Down
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/mps.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,9 @@ def _print_ToFloat(self, expr: sympy.Expr) -> str:
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
x = self.doprint(expr.args[0])
return f"static_cast<int>(metal::floor({x}))"
return f"static_cast<int>(metal::floor(static_cast<float>({x})))"

_print_floor = _print_FloorToInt

def _print_TruncToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
Expand Down
9 changes: 8 additions & 1 deletion torch/_inductor/inductor_prims.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,20 @@ def eager_prepare_softmax(x: Tensor, dim: int) -> tuple[Tensor, Tensor]:


def _flattened_index_to_nd(indices, width):
import sympy

from torch.utils._sympy.functions import FloorDiv

dim = len(width)

if dim == 1:
return [indices]
elif dim >= 2:
m = functools.reduce(operator.mul, width[1:])
ih = indices // m
if isinstance(indices, sympy.Expr) or isinstance(m, sympy.Expr):
ih = FloorDiv(indices, m)
else:
ih = indices // m
indices_new = indices - (ih * m)
return [ih, *_flattened_index_to_nd(indices_new, width[1:])]
else:
Expand Down
191 changes: 102 additions & 89 deletions torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -2600,7 +2600,6 @@ def is_aligned(x):
# WIP
make_fallback(aten._adaptive_avg_pool3d) # @isuruf
make_fallback(aten.adaptive_max_pool3d) # @isuruf
make_fallback(aten.fractional_max_pool3d) # @isuruf
make_fallback(aten._scaled_dot_product_attention_math_for_mps) # @malfet


Expand Down Expand Up @@ -4472,34 +4471,27 @@ def _low_memory_max_pool_with_offsets(
return result, to_dtype(offsets, torch.int8)


@register_lowering(
prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool_offsets_to_indices(
offsets, kernel_size, input_size, stride, padding, dilation
):
# TODO: Generalize to other max pooling flavors
def _pool_offsets_to_indices(
offsets: TensorBox,
kernel_size: Sequence[Union[int, torch.SymInt]],
input_size: Sequence[Union[int, torch.SymInt]],
increments_to_index: Callable[
[Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
torch._inductor.virtualized.OpsValue,
],
) -> TensorBox:
n_dim = len(kernel_size)
offsets_loader = offsets.make_loader()

def increments_to_index(dhw_inc, bh):
w_in = [ops.index_expr(input_size[d], torch.int64) for d in range(n_dim)]
hbase = [
EDBE ops.index_expr(bh[d] * stride[d] - padding[d], torch.int64)
for d in range(n_dim)
]
idhw = [
hbase[d] + dhw_inc[d] * ops.constant(dilation[d], torch.int64)
for d in range(n_dim)
]
return inductor_prims._flatten_index(idhw, w_in)
window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

def offsets_to_indices(idx):
bh = idx[-n_dim:]
offset = offsets_loader(idx)
k_const = [ops.constant(kernel_size[d], torch.int32) for d in range(n_dim)]
dhw_inc = inductor_prims._flattened_index_to_nd(offset, k_const)
return increments_to_index(dhw_inc, bh)
offset_sympy = ops.indirect_indexing(offset, window_size)
reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
idhw = increments_to_index(idx, reduction_idx)
return ops.index_expr(
inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
)

indices = Pointwise.create(
device=offsets.get_device(),
Expand All @@ -4510,6 +4502,27 @@ def offsets_to_indices(idx):
return indices


@register_lowering(
prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool_offsets_to_indices(
offsets, kernel_size, input_size, stride, padding, dilation
):
# TODO: Generalize to other max pooling flavors
n_dim = len(kernel_size)

def increments_to_index(idx, reduction_idx):
bh = idx[-n_dim:]
return [
(bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
for i in range(n_dim)
]

return _pool_offsets_to_indices(
offsets, kernel_size, input_size, increments_to_index
)


def _max_pool_with_indices(
x,
kernel_size,
Expand Down Expand Up @@ -5013,11 +5026,6 @@ def inner_fn_max_idx(idx):
return rv, ri


fallback_fractional_max_pool2d = fallback_handler(
aten.fractional_max_pool2d.default, add_to_fallback_set=False
)


def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
out_sz = out_sz[dim]
in_sz = in_sz[dim]
Expand All @@ -5036,80 +5044,85 @@ def load(prefix, i):
seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
seq_i = ops.to_dtype(seq_i, torch.int64)
mask = ops.lt(i_expr, out_sz_expr)
return ops.where(mask, seq_i, diff)
return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))

return load


@register_lowering(aten.fractional_max_pool2d)
def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
x.realize_hint()
*batch, inp_h, inp_w = x.get_size()
kernel_h, kernel_w = kernel_size
h_out, w_out = output_size
return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)

if kernel_h * kernel_w >= 25:
return fallback_fractional_max_pool2d(
x, kernel_size, output_size, random_samples
)

gen_offsets_for_dim = functools.partial(
_fractional_pooling_offsets,
samples=random_samples,
in_sz=[inp_h, inp_w],
out_sz=output_size,
kernel_sz=kernel_size,
ndims=2,
)
@register_lowering(aten.fractional_max_pool3d)
def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)

h_index_fn = gen_offsets_for_dim(dim=0)
w_index_fn = gen_offsets_for_dim(dim=1)
x_loader = x.make_loader()

def fn(idx, return_index):
*prefix, bh, bw = idx
def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
x.realize_hint()
batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

h_start_index = ops.indirect_indexing(h_index_fn(prefix, bh), inp_h)
w_start_index = ops.indirect_indexing(w_index_fn(prefix, bw), inp_w)
with config.patch(unroll_reductions_threshold=25):
dhw_index_fn = [
_fractional_pooling_offsets(
samples=random_samples,
in_sz=inp_dhw,
out_sz=output_size,
kernel_sz=kernel_size,
ndims=n_dim,
dim=d,
)
for d in range(n_dim)
]

maxval = None
maxindex = None
for ih, iw in itertools.product(range(kernel_size[0]), range(kernel_size[1])):
val = x_loader([*prefix, h_start_index + ih, w_start_index + iw])
if return_index:
index = ops.index_expr(
(h_start_index + ih) * inp_w + w_start_index + iw, torch.int64
)
if maxindex is None:
maxindex = index
else:
maxindex = ops.where(
ops.or_(ops.gt(val, maxval), ops.isnan(val)), index, maxindex
)
if maxval is None:
maxval = val
else:
maxval = ops.maximum(val, maxval)
if return_index:
return maxindex
else:
return maxval
x_loader = x.make_loader()

new_size = list(batch) + [h_out, w_out]
rv = Pointwise.create(
device=x.get_device(),
dtype=x.get_dtype(),
inner_fn=functools.partial(fn, return_index=False),
ranges=new_size,
)
def fn_inner(idx, reduction_idx):
prefix = idx[:-n_dim]
return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

ri = Pointwise.create(
device=x.get_device(),
dtype=torch.int64,
inner_fn=functools.partial(fn, return_index=True),
ranges=new_size,
)
return rv, ri
def increments_to_index(idx, reduction_idx):
prefix = idx[:-n_dim]
bdhw = idx[-n_dim:]
return [
dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
for d in range(n_dim)
]

new_size = list(batch) + list(output_size)
dtype = x.get_dtype()
result = Reduction.create(
reduction_type="max",
input_node=x,
device=x.get_device(),
dst_dtype=dtype,
src_dtype=dtype,
inner_fn=fn_inner,
ranges=new_size,
reduction_ranges=kernel_size,
)
offsets = Reduction.create(
reduction_type="argmax",
input_node=x,
device=x.get_device(),
dst_dtype=torch.int64,
src_dtype=dtype,
inner_fn=fn_inner,
ranges=new_size,
reduction_ranges=kernel_size,
)
if isinstance(result.data.data, Reduction): # type: ignore[attr-defined]
# Only realize if reduction isn't unrolled
result.realize()
if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined]
# Only realize if reduction isn't unrolled
offsets.realize()

indices = _pool_offsets_to_indices(
offsets, kernel_size, x.shape, increments_to_index
)
return result, indices


@register_lowering(aten.upsample_nearest2d_backward.default)
Expand Down
Loading
0