[inductor] [bug fix] align avg_pool with eager when handling uint #144313
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144313
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 467c37f with merge base 69b883d. UNSTABLE: the following job failed but was likely due to flakiness present on trunk and has been marked as unstable.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchbot label "topic: not user facing"
test/inductor/test_torchinductor.py (Outdated)

```diff
@@ -6215,6 +6215,18 @@ def forward(self, x):
         ):
             model(x)

+    def test_avg_pool_errors_with_uint(self):
+        torch._dynamo.config.recompile_limit = 12
```
why do we do this?
In my local debugging, I found that the default is recompile_limit=8, but this test needs to compile the model 12 times (3 different dims × 4 different uint dtypes). I'm not sure how the UTs in test_torchinductor.py set the recompile_limit, but if I understand correctly, the different UTs should be independent of each other and not affect each other, so I added this config here. Feel free to correct me if I'm wrong.
use `@torch._dynamo.config.patch(recompile_limit=12)`
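For reference, a minimal sketch of how the test could adopt that decorator. The loop structure, shapes, and kernel size here are illustrative assumptions, not the PR's exact diff:

```python
import torch
import torch._dynamo
import torch.nn.functional as F

# Sketch of the test body (it would live on the existing TestCase).
# 3 spatial dims x 4 unsigned dtypes = 12 distinct compilations, hence the
# decorator-scoped recompile_limit instead of mutating the global config.
@torch._dynamo.config.patch(recompile_limit=12)
def test_avg_pool_errors_with_uint(self):
    pools = {1: F.avg_pool1d, 2: F.avg_pool2d, 3: F.avg_pool3d}
    for dim, pool in pools.items():
        for dtype in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
            x = torch.zeros((1, 1) + (4,) * dim).to(dtype)
            compiled = torch.compile(pool)
            # Eager rejects unsigned-integer inputs; compiled should match.
            with self.assertRaises(RuntimeError):
                compiled(x, kernel_size=2)
```

The patch decorator restores the previous limit when the test exits, which keeps the UTs independent of each other, addressing the concern above.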
torch/_inductor/lowering.py (Outdated)

```diff
@@ -4957,6 +4957,9 @@ def _avg_poolnd(
     assert len(stride) == dim
     assert len(padding) == dim
     assert len(x.get_size()) in (dim + 1, dim + 2)
+    if x.get_dtype() in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
```
@shaoyuyoung, the general contract is that the operator meta function should handle error checking. Can we move this to pytorch/torch/_meta_registrations.py, lines 2510 to 2511 in 9ee2422:

```python
@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
```
Putting this check in the meta function is a better fix.
thanks for reviewing, I will do this work in my spare time.
? not sure where the avg_pool1d registration is

It seems that avg_pool1d shares the same registration with 2d (a quick sanity check below).
have updated, mind helping me review again? :)
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 mandatory check(s) failed. Dig deeper by viewing the failures on hud.

CI seems broken just now (?)

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #144310

~~We just need to add a check in lowering.~~ Updated: we add the error checking in the meta registration instead.

UT: test_avg_pool_errors_with_uint in test/inductor/test_torchinductor.py (a repro sketch follows below).
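A hedged repro sketch of the reported mismatch; the exact error text is an assumption, the point being that eager raises for unsigned-integer inputs and, after this PR, the compiled path raises the same way:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 8, 8, dtype=torch.uint8)

# Eager: avg_pool2d rejects unsigned-integer inputs with a RuntimeError.
try:
    F.avg_pool2d(x, kernel_size=2)
except RuntimeError as e:
    print("eager:", e)

# Compiled: before this fix inductor silently lowered the op; with the check
# in the meta registration it now raises the matching error.
compiled = torch.compile(F.avg_pool2d)
try:
    compiled(x, kernel_size=2)
except RuntimeError as e:
    print("compiled:", e)
```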
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov