torch.cuda.is_bf16_compatible() output inconsistent with TorchInductor support #118122
Comments
cc @ysiraichi
If eager works with bfloat16 data types on CUDA, I would say compile should handle it as well. @ezyang, any thoughts?
Well, IMO we should fix either Inductor or Triton to produce working binaries rather than crash with unsupported instructions.
Hi from Colab! This is causing problems for users on Colab (example). The gist is that torch.compile with bf16 on a T4 causes some semi-opaque errors.

Seems like poor UX to me. Any thoughts on the appropriate fix here? Should the fix be in Triton or torch? My understanding is that XLA handles this by casting to f32 (e.g. openxla/xla#12429).
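As a concrete illustration of that kind of fallback, here is a hedged sketch only: the `pick_dtype` helper and the toy `nn.Linear` model are made up for this example, and the sm_80 threshold comes from the Inductor limitation discussed in this issue.

```python
# Sketch of a user-side workaround: use bfloat16 only where the GPU has native
# support (compute capability >= 8.0), otherwise fall back to float32 so that
# torch.compile/Triton does not crash on T4 (sm_75) or V100 (sm_70).
import torch

def pick_dtype() -> torch.dtype:
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
        return torch.bfloat16
    return torch.float32

dtype = pick_dtype()
model = torch.nn.Linear(16, 16, device="cuda", dtype=dtype)
compiled = torch.compile(model)
print(compiled(torch.randn(4, 16, device="cuda", dtype=dtype)).dtype)
```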
Considering this affects Google Colab, marking as high priority until we find a way to unblock.
Ok, I think we need to separate whether something is "supported" (read: emulated) for eager vs. compile, though it does not look like those two changes are related. Also, this is not a regression, i.e. 2.2 had the same behavior, see https://colab.research.google.com/drive/1rIy_MJSUcdV8nQu_uuFZsniLR6uhS1BR?usp=sharing where

```python
import torch
import triton

torch.cuda._check_bf16_tensor_supported = lambda x: False
print(torch.__version__, triton.__version__, torch.cuda.get_device_properties(0), torch.cuda.is_bf16_supported())

@torch.compile
def fn(inp, src, index):
    return inp.scatter_add(0, index, src)

with torch.device("cuda"):
    dtype = torch.bfloat16
    inp = torch.zeros(3, 5, dtype=dtype)
    src = torch.ones((2, 5), dtype=dtype)
    index = torch.tensor([[0, 1, 2, 0, 0]])
    print(fn(inp, src, index))
```

results in the same failure.
Inductor to fail gracefully on Voltas for bf16 tensors (#129288)

Volta (sm_7x) does not have hardware support for the bfloat16 datatype, but it is emulated in software, so PyTorch eager can use bfloat16 tensors while Triton cannot. So if a graph with either CUDA bf16 input or output tensors is encountered, raise a warning and skip the frame.

Add an optional parameter `including_emulation` to `torch.cuda.is_bf16_supported` and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: modify `is_bf16_supported` to return False and check that the warning is generated.

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel

(cherry picked from commit 14dc08d)
Co-authored-by: Nikita Shulga <nshulga@meta.com>
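Roughly, the check described in that PR text could look like the sketch below; the function and parameter names follow the wording above, but this is not the actual Inductor implementation, and the call assumes a PyTorch build that already includes #129288.

```python
# Sketch (not the real torch._inductor.compile_fx code) of the described check:
# warn and skip compilation instead of crashing when bf16 is only emulated.
import warnings
import torch

def check_triton_bf16_support() -> bool:
    # `including_emulation=False` asks about native HW bf16 support only;
    # Volta (sm_7x) would return False here even though eager bf16 works.
    if torch.cuda.is_bf16_supported(including_emulation=False):
        return True
    warnings.warn(
        "bfloat16 is not natively supported on this GPU; "
        "skipping torch.compile for this graph and falling back to eager."
    )
    return False
```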
🐛 Describe the bug
The recent change to `torch.cuda.is_bf16_compatible()` is now labeling Turing (sm_75) and Volta (sm_70) cards as compatible with `torch.bfloat16` tensors. Meanwhile, TorchInductor supports `torch.bfloat16` only on sm_80 or higher.

This has caused Transformer Engine CI tests that would normally skip on Turing and Volta nodes to instead crash with a `BackendCompilerFailed` error from TorchDynamo.

We are replacing the `torch.cuda.is_bf16_compatible()` condition with explicit checks on `torch.cuda.get_device_capability()` as a work-around. However, I wanted to file this issue here regardless, just in case it was an unintentional consequence and may be considered a bug.
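For illustration, that work-around might look roughly like the sketch below (the helper name is hypothetical; the sm_80 threshold reflects the Inductor limitation described above).

```python
# Sketch of the work-around: gate bf16 tests on compute capability directly
# instead of the bf16-support query, since Inductor needs sm_80 or newer.
import torch

def bf16_compile_supported() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

# Example use in a test:
# if not bf16_compile_supported():
#     pytest.skip("bf16 + torch.compile requires sm_80 or newer")
```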
Versions

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov