Inductor to fail gracefully on Voltas for bf16 tensors · pytorch/pytorch@ddf4bb2 · GitHub
Commit ddf4bb2

Inductor to fail gracefully on Voltas for bf16 tensors
Voltas do not have hardware support for the bfloat16 datatype, but the type is emulated in software, so PyTorch eager mode can use bfloat16 tensors while Triton cannot. If a graph has either CUDA bf16 input or output tensors, raise a warning and skip the frame. Fixes #118122 and #118581
1 parent 1c75ddf commit ddf4bb2
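
As a hedged illustration of the behavior described above (not part of the commit), compiling a function that takes CUDA bf16 tensors on a Volta-class GPU should now warn and fall back to eager execution instead of failing inside Triton. The toy function, shapes, and the device name in the comment are assumptions for the example:

# Illustrative sketch: assumes a Volta-class GPU (e.g. V100), where eager-mode
# bfloat16 works via software emulation but Triton has no native bf16 support.
import torch

def f(x):
    return torch.nn.functional.gelu(x) * 2

x = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
compiled = torch.compile(f)

# Expected on Volta after this commit: a warning along the lines of
# "Tesla V100-SXM2-16GB does not support bfloat16 compilation natively, skipping",
# after which the frame is skipped and runs in eager mode.
out = compiled(x)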

File tree

2 files changed: +33 -1 lines changed


torch/_inductor/compile_fx.py

Lines changed: 29 additions & 0 deletions
@@ -858,6 +858,7 @@ def fx_codegen_and_compile(
                     else:
                         output_strides.append(None)
 
+            _check_triton_bf16_support(graph)
             compiled_fn = graph.compile_to_fn()
             num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
             metrics.num_bytes_accessed += num_bytes
@@ -1628,3 +1629,31 @@ def wrapper(*args):
         return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args)))
 
     return wrapper
+
+
+def _check_triton_bf16_support(graph: GraphLowering) -> None:
+    def warn_and_skip(device) -> None:
+        from torch._dynamo.exc import SkipFrame
+        device_props = torch.cuda.get_device_properties(device)
+        warnings.warn(f"{device_props.name} does not support bfloat16 compilation natively, skipping")
+        raise SkipFrame("BF16 is not supported")
+
+    for inp in graph.graph_inputs.values():
+        device = inp.get_device()
+        if device.type != "cuda" or inp.get_dtype() != torch.bfloat16:
+            continue
+        # Print warning and skip frame if attempting to compile for bfloat16
+        # on device without hardware support for dtype
+        if torch.cuda.is_bf16_supported(including_emulation=False):
+            return
+        warn_and_skip(device)
+
+    for out in graph.graph_outputs:
+        device = out.get_device()
+        if device.type != "cuda" or out.get_dtype() != torch.bfloat16:
+            continue
+        # Print warning and skip frame if attempting to compile for bfloat16
+        # on device without hardware support for dtype
+        if torch.cuda.is_bf16_supported(including_emulation=False):
+            return
+        warn_and_skip(device)
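
_check_triton_bf16_support raises torch._dynamo.exc.SkipFrame, which tells Dynamo to leave the frame uncompiled so it keeps running in eager mode, where emulated bf16 works. For context, here is a minimal standalone sketch of the same guard written against plain tensors rather than Inductor's GraphLowering; the helper name and tensor-based signature are illustrative assumptions, not part of this commit:

import warnings
import torch

def tensors_need_bf16_skip(tensors) -> bool:
    """Return True if any CUDA bf16 tensor sits on a device without native
    bf16 support, mirroring the check added in this commit."""
    for t in tensors:
        if t.device.type != "cuda" or t.dtype != torch.bfloat16:
            continue
        # Strict check: ignore the software-emulation path that eager mode uses.
        if torch.cuda.is_bf16_supported(including_emulation=False):
            return False
        props = torch.cuda.get_device_properties(t.device)
        warnings.warn(f"{props.name} does not support bfloat16 compilation natively, skipping")
        return True
    return False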

torch/cuda/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -128,7 +128,7 @@ def is_available() -> bool:
     return torch._C._cuda_getDeviceCount() > 0
 
 
-def is_bf16_supported():
+def is_bf16_supported(including_emulation: bool = True):
     r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16."""
     # Check for ROCm, if true return true, no ROCM_VERSION check required,
     # since it is supported on AMD GPU archs.
@@ -147,6 +147,9 @@ def is_bf16_supported():
     ):
         return True
 
+    if not including_emulation:
+        return False
+
     # Finally try to create a bfloat16 device.
     return _check_bf16_tensor_supported(device)
 
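
A short usage sketch of the extended check (illustrative; the commented results assume a Volta-class GPU such as a V100, where bf16 is emulated in software):

import torch

if torch.cuda.is_available():
    # Default behavior is unchanged: emulation counts, so eager-mode bf16 keeps working.
    print(torch.cuda.is_bf16_supported())                           # True on Volta (emulated)
    # Strict check used by Inductor: only native hardware bf16 counts.
    print(torch.cuda.is_bf16_supported(including_emulation=False))  # False on Volta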

0 commit comments