8000 [Inductor][CPP] Fix Inductor integer avg pool by DDEle · Pull Request #144059 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

[Inductor][CPP] Fix Inductor integer avg pool #144059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from

Conversation

DDEle
Copy link
Contributor
@DDEle DDEle commented Jan 2, 2025

Fixes #143738. Currently the scaler for averaging is rounded to 0 if dtype is an integer, resulting to all-zero output. This fix uses truediv instead for integer cases.

Test

pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool1d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool2d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_avg_pool3d_cpu_int64
pytest -vs ./test/inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCPU::test_comprehensive_nn_functional_local_response_norm_cpu_int64

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @chauhang @aakhundov

Copy link
pytorch-bot bot commented Jan 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/144059

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c0c38e3 with merge base a174ee2 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@DDEle
Copy link
Contributor Author
DDEle commented Jan 2, 2025

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Jan 2, 2025
@leslie-fang-intel leslie-fang-intel added the ciflow/trunk Trigger trunk jobs on your pull request label Jan 2, 2025
Copy link
Collaborator
@leslie-fang-intel leslie-fang-intel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a CPU specific issue? If not, suggest to move the UT to test/inductor/test_torchinductor.py

@DDEle
Copy link
Contributor Author
DDEle commented Jan 2, 2025

Is it a CPU specific issue? If not, suggest to move the UT to test/inductor/test_torchinductor.py

The original bug reporter said that BTW, cuda would reject the Long dtype. So, I guess it is sort of CPU specific?

In detail, CUDA complains that RuntimeError: "avg_pool2d_out_cuda_frame" not implemented for 'Long'. XPU complains that RuntimeError: "avg_pool2d_xpu" not implemented for 'Long' .

@DDEle
Copy link
Contributor Author
DDEle commented Jan 6, 2025

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@DDEle DDEle deleted the fix-inductor-int-avgpool branch January 6, 2025 01:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[inductor] [cpu] [silent] avg_pool2d incorrectly process int64
6 participants
0