-
Notifications
You must be signed in to change notification settings - Fork 24.3k
[inductor] support dilation in max_pool2d lowering #148209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
3b5d46f
Update
isuruf ea68534
Update
isuruf 43d027d
Update
isuruf 2cbd545
Update
isuruf d519202
Update
isuruf f65ee2b
Update
isuruf 85d0971
Update
isuruf 23415c0
index_expr -> constant
isuruf 40769ad
Update
isuruf 0d28bb7
update testing
isuruf 7a18cf8
update mps tests
isuruf 935370a
Update
isuruf 235476a
update mps test
isuruf 4a15f7a
Update
isuruf File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4319,14 +4319,22 @@ def load(index): | |
return load | ||
|
||
|
||
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): | ||
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None): | ||
if dilation is None: | ||
dilation = [1] * len(padding) | ||
|
||
x_out = FloorDiv( | ||
x + 2 * padding[i] - (kernel_size[i] - 1) + (stride[i] - 1), stride[i] | ||
x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1), | ||
stride[i], | ||
) | ||
|
||
if ceil_mode: | ||
x_alt = FloorDiv( | ||
x + 2 * padding[i] - (kernel_size[i] - 1) + 2 * (stride[i] - 1), stride[i] | ||
x | ||
+ 2 * padding[i] | ||
- dilation[i] * (kernel_size[i] - 1) | ||
+ 2 * (stride[i] - 1), | ||
stride[i], | ||
) | ||
if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0: | ||
# Sliding windows must start within the input or left padding | ||
|
@@ -4341,10 +4349,10 @@ def pooling_size(x, i, kernel_size, stride, padding, ceil_mode): | |
return x_out, ceil_mode | ||
|
||
|
||
def should_fallback_max_pool2d_with_indices(kernel_size, dilation): | ||
def should_fallback_max_pool2d_with_indices(kernel_size): | ||
kernel_size = pad_listlike(kernel_size, 2) | ||
window_size = kernel_size[0] * kernel_size[1] | ||
return (window_size > 25) or any(d > 1 for d in dilation) | ||
return window_size > 25 | ||
|
||
|
||
def max_pool2d_checks( | ||
|
@@ -4369,7 +4377,7 @@ def max_pool2d_checks( | |
assert len(dilation) == 2 | ||
assert len(x.get_size()) in (3, 4) | ||
|
||
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size, dilation) | ||
use_fallback = should_fallback_max_pool2d_with_indices(kernel_size) | ||
if assert_fallback is not None: | ||
assert use_fallback == assert_fallback | ||
|
||
|
@@ -4387,8 +4395,12 @@ def _max_pool2d_with_offsets( | |
x.realize_hint() | ||
*batch, h, w = x.get_size() | ||
|
||
h_out, ceil_mode1 = pooling_size(h, 0, kernel_size, stride, padding, ceil_mode) | ||
w_out, ceil_mode2 = pooling_size(w, 1, kernel_size, stride, padding, ceil_mode) | ||
h_out, ceil_mode1 = pooling_size( | ||
h, 0, kernel_size, stride, padding, ceil_mode, dilation=dilation | ||
) | ||
w_out, ceil_mode2 = pooling_size( | ||
w, 1, kernel_size, stride, padding, ceil_mode, dilation=dilation | ||
) | ||
|
||
dtype = x.dtype | ||
min_value = ( | ||
|
@@ -4398,7 +4410,14 @@ def _max_pool2d_with_offsets( | |
) | ||
|
||
new_size = list(batch) + [h_out, w_out] | ||
if padding[0] or padding[1] or ceil_mode1 or ceil_mode2: | ||
if ( | ||
padding[0] | ||
or padding[1] | ||
or ceil_mode1 | ||
or ceil_mode2 | ||
or (dilation[0] > 1) | ||
or (dilation[1] > 1) | ||
): | ||
x_loader = constant_boundary_condition(x, min_value, dim=2) | ||
else: | ||
x_loader = x.make_loader() | ||
|
@@ -4408,7 +4427,10 @@ def _max_pool2d_with_offsets( | |
def fn_inner(idx, reduction_idx): | ||
prefix = idx[:-dim] | ||
bh = idx[-dim:] | ||
ih = [bh[i] * stride[i] + reduction_idx[i] - padding[i] for i in range(dim)] | ||
ih = [ | ||
(bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i] | ||
for i in range(dim) | ||
] | ||
return x_loader([*prefix, *ih]) | ||
|
||
result = Reduction.create( | ||
|
@@ -4476,7 +4498,7 @@ def _low_memory_max_pool2d_with_offsets( | |
prims._low_memory_max_pool2d_offsets_to_indices, type_promotion_kind=None | ||
) | ||
def _low_memory_max_pool2d_offsets_to_indices( | ||
offsets, kernel_width, input_width, stride, padding | ||
offsets, kernel_width, input_width, stride, padding, dilation | ||
): | ||
# TODO: Generalize to other max pooling flavors, and arbitrary dim | ||
|
||
|
@@ -4486,8 +4508,8 @@ def increments_to_index(h_inc, w_inc, bh, bw): | |
w_in = ops.index_expr(input_width, torch.int64) | ||
hbase = ops.index_expr(bh * stride[0] - padding[0], torch.int64) | ||
wbase = ops.index_expr(bw * stride[1] - padding[1], torch.int64) | ||
ih = hbase + h_inc | ||
iw = wbase + w_inc | ||
ih = hbase + h_inc * ops.constant(dilation[0], torch.int64) | ||
iw = wbase + w_inc * ops.constant(dilation[1], torch.int64) | ||
return ih * w_in + iw | ||
|
||
def offsets_to_indices(idx): | ||
|
@@ -4507,12 +4529,6 @@ def offsets_to_indices(idx): | |
return indices | ||
|
||
|
||
fallback_max_pool2d_with_indices = fallback_handler( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🎉 |
||
aten.max_pool2d_with_indices.default, | ||
add_to_fallback_set=False, | ||
) | ||
|
||
|
||
# Fallback when we do not decompose to the low-memory path. | ||
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None) | ||
def max_pool2d_with_indices( | ||
|
@@ -4527,17 +4543,12 @@ def max_pool2d_with_indices( | |
x, kernel_size, stride, padding, dilation | ||
) | ||
|
||
if any(d > 1 for d in dilation): | ||
return fallback_max_pool2d_with_indices( | ||
x, kernel_size, stride, padding, dilation, ceil_mode=ceil_mode | ||
) | ||
|
||
out, offsets = _max_pool2d_with_offsets( | ||
x, kernel_size, stride, padding, dilation, ceil_mode | ||
) | ||
|
||
indices = _low_memory_max_pool2d_offsets_to_indices( | ||
offsets, kernel_size[-1], x.shape[-1], stride, padding | ||
offsets, kernel_size[-1], x.shape[-1], stride, padding, dilation | ||
) | ||
|
||
return out, indices | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.