`torch.cuda.is_bf16_compatible()` output inconsistent with TorchInductor support · Issue #118122 · pytorch/pytorch · GitHub

torch.cuda.is_bf16_compatible() output inconsistent with TorchInductor support #118122


Closed
Tracked by #130151
denera opened this issue Jan 23, 2024 · 6 comments
Labels
high priority module: inductor oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@denera
denera commented Jan 23, 2024

🐛 Describe the bug

A recent change to torch.cuda.is_bf16_compatible() now labels Turing (sm_75) and Volta (sm_70) cards as compatible with torch.bfloat16 Tensors, while TorchInductor supports torch.bfloat16 only on sm_80 or higher.

This has caused Transformer Engine CI tests that would normally skip on Turing and Volta nodes to instead crash with a BackendCompilerFailed error from TorchDynamo.

We are replacing the torch.cuda.is_bf16_compatible() condition with explicit checks on torch.cuda.get_device_capability() as a work-around. However, I wanted to file this issue regardless, in case this is an unintended consequence and should be considered a bug.
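
For reference, the work-around amounts to gating on compute capability; a minimal sketch of the idea (the helper name and pytest marker below are illustrative, not the actual Transformer Engine code):

import torch

def bf16_compile_supported() -> bool:
    # Triton/Inductor emit native bf16 PTX instructions, which need sm_80 (Ampere) or newer.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

# e.g. in the CI tests:
# @pytest.mark.skipif(not bf16_compile_supported(),
#                     reason="bf16 with torch.compile requires sm_80 or higher")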

Versions

Collecting environment information...
PyTorch version: 2.3.0a0+ebedce2
Is debug build: False
CUDA used to build PyTorch: 12.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.28.1
Libc version: glibc-2.35

Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-4.15.0-123-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: Tesla V100S-PCIE-32GB
GPU 1: Tesla V100S-PCIE-32GB
GPU 2: Tesla V100S-PCIE-32GB
GPU 3: Tesla V100S-PCIE-32GB
GPU 4: Tesla V100S-PCIE-32GB
GPU 5: Tesla V100S-PCIE-32GB
GPU 6: Tesla V100S-PCIE-32GB
GPU 7: Tesla V100S-PCIE-32GB

Nvidia driver version: 525.85.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          88
On-line CPU(s) list:             0-87
Vendor ID:                       GenuineIntel
Model name:                      Intel(R) Xeon(R) CPU E5-2699 v4 @ 2.20GHz
CPU family:                      6
Model:                           79
Thread(s) per core:              2
Core(s) per socket:              22
Socket(s):                       2
Stepping:                        1
CPU max MHz:                     2200.0000
CPU min MHz:                     1200.0000
BogoMIPS:                        4399.76
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cdp_l3 invpcid_single intel_ppin ssbd ibrs ibpb stibp tpr_shadow vnmi flexpriority ept vpid fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm cqm rdt_a rdseed adx smap intel_pt xsaveopt cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local dtherm arat pln pts md_clear flush_l1d
Virtualization:                  VT-x
L1d cache:                       1.4 MiB (44 instances)
L1i cache:                       1.4 MiB (44 instances)
L2 cache:                        11 MiB (44 instances)
L3 cache:                        110 MiB (2 instances)
NUMA node(s):                    2
NUMA node0 CPU(s):               0-21,44-65
NUMA node1 CPU(s):               22-43,66-87
Vulnerability Itlb multihit:     KVM: Vulnerable
Vulnerability L1tf:              Mitigation; PTE Inversion; VMX vulnerable
Vulnerability Mds:               Vulnerable; SMT vulnerable
Vulnerability Meltdown:          Vulnerable
Vulnerability Spec store bypass: Vulnerable
Vulnerability Spectre v1:        Vulnerable: __user pointer sanitization and usercopy barriers only; no swapgs barriers
Vulnerability Spectre v2:        Vulnerable, IBPB: disabled, STIBP: disabled
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Vulnerable

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] onnx==1.15.0rc2
[pip3] optree==0.10.0
[pip3] pytorch-quantization==2.1.2
[pip3] torch==2.3.0a0+ebedce2
[pip3] torch-tensorrt==2.2.0a0
[pip3] torchdata==0.7.1a0
[pip3] torchtext==0.17.0a0
[pip3] torchvision==0.18.0a0
[pip3] triton==2.2.0
[conda] Could not collect

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @chauhang @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov

@ezyang
Contributor
ezyang commented Jan 24, 2024

cc @ysiraichi

@mlazos mlazos added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: inductor labels Jan 29, 2024
@ysiraichi
Collaborator
ysiraichi commented Jan 30, 2024

If eager works with bfloat16 data types on CUDA, I would say that torch.cuda.is_bf16_available() == True makes sense. If I understand this correctly, the error only occurs with inductor. Maybe we should have something similar within inductor: torch._inductor.is_bf16_available().

@ezyang any thoughts?

@malfet
Contributor
malfet commented Jan 30, 2024

Well, IMO we should fix either Inductor or Triton to produce working binaries rather than crash on unsupported instructions.
Since I have a local setup, I can try working on a fix.

@sagelywizard

Hi from Colab!

This is causing problems for users on Colab (example). The gist is that torch.compile with bf16 on a T4 causes semi-opaque errors, e.g.:

---------------------------------------------------------------------------
BackendCompilerFailed                     Traceback (most recent call last)
[<ipython-input-23-085f0fe5825f>](https://localhost:8080/#) in <cell line: 1>()
----> 1 fast_add(a.to(torch.bfloat16), b.to(torch.bfloat16))

49 frames
[/usr/lib/python3.10/concurrent/futures/_base.py](https://localhost:8080/#) in __get_result(self)
    401         if self._exception:
    402             try:
--> 403                 raise self._exception
    404             finally:
    405                 # Break a reference cycle with the exception in self._exception

BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-2d5bbc, line 50; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-2d5bbc, line 50; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-2d5bbc, line 58; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-2d5bbc, line 58; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-2d5bbc, line 66; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-2d5bbc, line 66; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors

Seems like poor UX to me. Any thoughts on the appropriate fix here? Should the fix be in triton or torch? My understanding is that XLA handles this by casting to f32 (e.g. openxla/xla#12429).
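
In the meantime, a user-side workaround sketch for pre-sm_80 GPUs such as the T4 (this assumes fast_add is the toy compiled function from the notebook above; its body here is just an illustrative add):

import torch

@torch.compile
def fast_add(a, b):
    return a + b  # stand-in for the notebook's compiled function

a = torch.randn(1024, device="cuda")
b = torch.randn(1024, device="cuda")

if torch.cuda.get_device_capability()[0] >= 8:
    out = fast_add(a.to(torch.bfloat16), b.to(torch.bfloat16))
else:
    # T4/V100: run the compiled kernel in fp32 and cast in eager mode, where bf16 is emulated.
    out = fast_add(a, b).to(torch.bfloat16)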

@msaroufim
Member

Considering this affects Google Colab, marking as high priority until we find a way to unblock.

@malfet malfet added this to the 2.4.0 milestone Jun 11, 2024
@malfet malfet assigned malfet and unassigned ysiraichi Jun 11, 2024
@malfet
Contributor
malfet commented Jun 11, 2024

Ok, I think we need to separate whether something is "supported" (read: emulated) for eager vs. compile.

Though it does not look like the two are related, i.e. even if is_bf16_supported returns False, compile still fails to produce correct code.

Also, this is not a regression, i.e. 2.2 had the same behavior, see https://colab.research.google.com/drive/1rIy_MJSUcdV8nQu_uuFZsniLR6uhS1BR?usp=sharing

Where

import torch
import triton

torch.cuda._check_bf16_tensor_supported = lambda x: False
print(torch.__version__, triton.__version__, torch.cuda.get_device_properties(0), torch.cuda.is_bf16_supported())


@torch.compile
def fn(inp, src, index):
    return inp.scatter_add(0, index, src)

with torch.device("cuda"):
    dtype = torch.bfloat16
    inp = torch.zeros(3, 5, dtype=dtype)
    src = torch.ones((2, 5), dtype=dtype)
    index = torch.tensor([[0, 1, 2, 0, 0]])

print(fn(inp, src, index))

Results in

2.2.2+cu121 2.2.0 _CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=15102MB, multi_processor_count=40) False
---------------------------------------------------------------------------
BackendCompilerFailed                     Traceback (most recent call last)
<ipython-input-4-342bd5fc6a13> in <cell line: 18>()
     16   index = torch.tensor([[0, 1, 2, 0, 0]])
     17 
---> 18 print(fn(inp, src, index))

45 frames
/usr/lib/python3.10/concurrent/futures/_base.py in __get_result(self)
    401         if self._exception:
    402             try:
--> 403                 raise self._exception
    404             finally:
    405                 # Break a reference cycle with the exception in self._exception

BackendCompilerFailed: backend='inductor' raised:
RuntimeError: Internal Triton PTX codegen error: 
ptxas /tmp/compile-ptx-src-6fbc78, line 48; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-6fbc78, line 48; error   : Feature 'cvt with .f32.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-6fbc78, line 52; error   : Feature '.bf16' requires .target sm_80 or higher
ptxas /tmp/compile-ptx-src-6fbc78, line 52; error   : Feature 'cvt.bf16.f32' requires .target sm_80 or higher
ptxas fatal   : Ptx assembly aborted due to errors

malfet added a commit that referenced this issue Jun 22, 2024
Voltas do not have HW support for the bfloat16 datatype, but the type is emulated in software, so PyTorch eager can use bfloat16 tensors while Triton cannot.
So if a graph with CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Fixes #118122 and #118581
pytorchbot pushed a commit that referenced this issue Jun 27, 2024
Volta (sm_7x) does not have HW support for the bfloat16 datatype; the type is emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Add optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: Modify `is_bf16_supported` to return False and see that a warning is generated

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel

(cherry picked from commit 14dc08d)
atalman pushed a commit that referenced this issue Jun 28, 2024
Inductor to fail gracefully on Voltas for bf16 tensors (#129288)

Volta (sm_7x) does not have HW support for the bfloat16 datatype; the type is emulated in software, so PyTorch eager can use bfloat16 tensors, but Triton cannot. So if a graph with CUDA bf16 input or output tensors is compiled, raise a warning and skip the frame.

Add optional parameter `including_emulation` to the `torch.cuda.is_bf16_supported` method and call it from `torch._inductor.compile_fx._check_triton_bf16_support`.

Test plan: Modify `is_bf16_supported` to return False and see that a warning is generated

Fixes #118122 and #118581

Pull Request resolved: #129288
Approved by: https://github.com/eqy, https://github.com/jansel

(cherry picked from commit 14dc08d)

Co-authored-by: Nikita Shulga <nshulga@meta.com>
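
With that change in place, callers can distinguish eager (possibly emulated) bf16 support from native support; a hedged usage sketch (the exact semantics of passing including_emulation=False are inferred from the commit message above, so treat that call as an assumption):

import torch

if torch.cuda.is_available():
    # True if bf16 tensors work at all in eager mode (possibly via software emulation).
    eager_bf16 = torch.cuda.is_bf16_supported()
    # Presumed True only with native HW bf16 (sm_80+), which is what Triton/Inductor need;
    # the parameter name comes from PR #129288 above.
    native_bf16 = torch.cuda.is_bf16_supported(including_emulation=False)
    print(f"eager bf16: {eager_bf16}, native bf16: {native_bf16}")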