test_scatter_bf16_cuda fails on V100 · Issue #118581 · pytorch/pytorch

Closed · Tracked by #130151
malfet opened this issue Jan 29, 2024 · 8 comments
Labels: high priority, module: inductor, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone: 2.4.0

@malfet
Contributor
malfet commented Jan 29, 2024

🐛 Describe the bug

While running inductor CI on a V100, I found that the above-mentioned test fails with unsupported PTX instructions:

% python3 inductor/test_torchinductor.py -v -k test_scatter_bf16_cuda
...
RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-f5ac42, line 48; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-f5ac42, line 48; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-f5ac42, line 52; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-f5ac42, line 52; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors
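
For reference, a minimal sketch of the failing pattern (not the exact test body; shapes and values are illustrative): in eager mode the bf16 scatter works because the type is emulated, but `torch.compile` hands the kernel to Triton, which emits bf16 PTX that sm_70 cannot assemble.

```python
import torch

def scatter_bf16(dst, index, src):
    # Scatter bf16 values along dim 0 -- the op exercised by the failing test.
    return dst.scatter(0, index, src)

dst = torch.zeros(3, 5, dtype=torch.bfloat16, device="cuda")
index = torch.tensor([[0, 1, 2, 0, 0]], device="cuda")
src = torch.ones(1, 5, dtype=torch.bfloat16, device="cuda")

scatter_bf16(dst, index, src)                 # eager: works, bf16 is emulated on Volta
torch.compile(scatter_bf16)(dst, index, src)  # V100 (pre-fix): ptxas error shown above
```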

Versions

2.2, nightly

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler

@tringwald
Collaborator

Probably related to #118122.

@eellison eellison added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jan 30, 2024
@eellison
Contributor

@malfet, still working on this?

@anijain2305
Contributor

Comments from triage meeting

  • V100 has limited support for bfloat16. Since Triton uses more advanced intrinsics, should we fall back?
  • Conclusion - fall back only for a few ops, like scatter (a user-level sketch of such a fallback follows below).
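
A hedged, user-level sketch of the kind of per-op fallback discussed above, using only stock PyTorch APIs (this is not the Inductor-internal mechanism): perform the scatter in float32 on pre-sm_80 GPUs and cast the result back.

```python
import torch

def scatter_bf16_safe(dst, dim, index, src):
    # Workaround sketch: on CUDA devices without native bf16 (compute capability < 8.0),
    # run the scatter in float32 and cast the result back to bfloat16.
    if (dst.is_cuda and dst.dtype == torch.bfloat16
            and torch.cuda.get_device_capability(dst.device)[0] < 8):
        return dst.float().scatter(dim, index, src.float()).to(torch.bfloat16)
    return dst.scatter(dim, index, src)
```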

@masnesral
Contributor

We're doing the weekly check-in on hi-pri issues that haven't been updated in a month. @malfet, any update?

@zou3519 zou3519 added this to the 2.4.0 milestone Jun 4, 2024
@malfet
Contributor Author
malfet commented Jun 11, 2024

This sounds like a duplicate of #118122

malfet added a commit that referenced this issue Jun 22, 2024
Voltas do not have HW support for the bfloat16 datatype, but the type is emulated in software, so PyTorch eager can use bfloat16 tensors while Triton cannot.
So if a graph with either CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Fixes #118122 and #118581
pytorchmergebot pushed a commit that referenced this issue Jun 25, 2024
Volta (sm_7x) does not have HW support for the bfloat16 datatype; it is only emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with either CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Add an optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: modify `is_bf16_supported` to return False and check that the warning is generated.

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel
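
A hedged usage sketch of the new parameter described in this PR (behavior per the description above, available once #129288 landed):

```python
import torch

if torch.cuda.is_available():
    # Default includes software emulation, so this is True even on Volta/Turing.
    print(torch.cuda.is_bf16_supported())
    # Native-only check used by Inductor's Triton gate: False on sm_70/sm_75 parts.
    print(torch.cuda.is_bf16_supported(including_emulation=False))
```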
@malfet
Contributor Author
malfet commented Jun 25, 2024

Closing, this was fixed by skipping compilation on Voltas for the bf16 dtype.

@malfet malfet closed this as completed Jun 25, 2024
pytorchbot pushed a commit that referenced this issue Jun 27, 2024
Volta (sm_7x) does not have HW support for the bfloat16 datatype; it is only emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with either CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Add an optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: modify `is_bf16_supported` to return False and check that the warning is generated.

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel

(cherry picked from commit 14dc08d)
atalman pushed a commit that referenced this issue Jun 28, 2024
Inductor to fail gracefully on Voltas for bf16 tensors (#129288)

Volta (sm_7x) does not have HW support for the bfloat16 datatype; it is only emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with either CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Add an optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: modify `is_bf16_supported` to return False and check that the warning is generated.

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel

(cherry picked from commit 14dc08d)

Co-authored-by: Nikita Shulga <nshulga@meta.com>
@atalman
Contributor
atalman commented Jul 19, 2024

Validated in Colab:

2.4.0+cu121 _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15102MB, multi_processor_count=40) True
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:1607: UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping
  warnings.warn(
tensor([[1., 0., 0., 1., 1.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]], device='cuda:0', dtype=torch.bfloat16)
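
A hedged reconstruction of the validation snippet (the exact Colab cell isn't shown in the issue; names and shapes are inferred from the printed output):

```python
import torch

print(torch.__version__, torch.cuda.get_device_properties(0), torch.cuda.is_bf16_supported())

def scatter_bf16(dst, index, src):
    return dst.scatter(0, index, src)

dst = torch.zeros(3, 5, dtype=torch.bfloat16, device="cuda")
index = torch.tensor([[0, 1, 2, 0, 0]], device="cuda")
src = torch.ones(1, 5, dtype=torch.bfloat16, device="cuda")

# On T4 (sm_75) this now warns "does not support bfloat16 compilation natively, skipping"
# and runs the graph in eager mode instead of failing in ptxas.
print(torch.compile(scatter_bf16)(dst, index, src))
```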

@steveepreston

UserWarning: Tesla T4 does not support bfloat16 compilation natively, skipping

same error
