8000 [inductor] [bug fix] align `avg_pool` with eager when handling `uint` by shaoyuyoung · Pull Request #144313 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[inductor] [bug fix] align avg_pool with eager when handling uint #144313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from

Conversation

shaoyuyoung
Copy link
Contributor
@shaoyuyoung shaoyuyoung commented Jan 7, 2025

Fixes #144310

We just need to add a check in lowering

updated: we add the error checking in meta registration

UT

 pytest -s -v test/inductor/test_torchinductor.py -k test_avg_pool_errors_with_uint

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

Copy link
pytorch-bot bot commented Jan 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144313

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 467c37f with merge base 69b883d (image):

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@shaoyuyoung
Copy link
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jan 7, 2025
@cpuhrsch cpuhrsch requested review from yanboliang and jgong5 January 7, 2025 06:39
@cpuhrsch cpuhrsch added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jan 7, 2025
@@ -6215,6 +6215,18 @@ def forward(self, x):
):
model(x)

def test_avg_pool_errors_with_uint(self):
torch._dynamo.config.recompile_limit = 12
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we do this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my local debugging, I found that the default is recompile_limit=8. But in this case, we need to compile the model 12 times (3 different dims * 4 different uint).
Not sure how the UT in test_torchinductor.py set the recompile_limit, but if I understand correctly, the different UTs should be independent of each other and not affect each other, so I added this config here.
Feel free to correct me if I am wrong

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use @torch._dynamo.config.patch(recompile_limit=12)

@@ -4957,6 +4957,9 @@ def _avg_poolnd(
assert len(stride) == dim
assert len(padding) == dim
assert len(x.get_size()) in (dim + 1, dim + 2)
if x.get_dtype() in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shaoyuyoung, the general contract is that the operator meta function should handle error checking. can we move this to

@register_meta(aten.avg_pool2d.default)
def meta_avg_pool2d(
? not sure where the avg_pool1d registration is

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds great, I'm currently looking for avg_pool1d registration.
@jansel , any comment? Because this PR follows the previous PR #143762

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting this check in the meta function is a better fix.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for reviewing, I will do this work in my spare time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

? not sure where the avg_pool1d registration is

It seems that avg_pool1d shares the same registration with 2d

@@ -6215,6 +6215,18 @@ def forward(self, x):
):
model(x)

def test_avg_pool_errors_with_uint(self):
torch._dynamo.config.recompile_limit = 12
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use @torch._dynamo.config.patch(recompile_limit=12)

@@ -4957,6 +4957,9 @@ def _avg_poolnd(
assert len(stride) == dim
assert len(padding) == dim
assert len(x.get_size()) in (dim + 1, dim + 2)
if x.get_dtype() in (torch.uint8, torch.uint16, torch.uint32, torch.uint64):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Putting this check in the meta function is a better fix.

@shaoyuyoung
Copy link
Contributor Author

have updated, mind helping me review again? :)

@jansel
Copy link
Contributor
jansel commented Jan 15, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 15, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team Raised by workflow job

Failing merge rule: Core Maintainers

@shaoyuyoung
Copy link
Contributor Author
shaoyuyoung commented Jan 16, 2025

CI seems broken just now (?)
merge failed

@eellison
Copy link
Contributor

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@shaoyuyoung shaoyuyoung deleted the fx_avg_pool_uint8 branch January 17, 2025 02:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ciflow/trunk Trigger trunk jobs on your pull request Merged module: inductor open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[inductor] [dtype propogation] avg_pool1d,2d,3d pass the check when handling uint8,16,32,64 while eager throws the error
7 participants
0