[Inductor] add support for disabling atomic adds (#151033) · pytorch/pytorch@fe96167 · GitHub

Commit fe96167

mlazos authored and pytorchmergebot committed

[Inductor] add support for disabling atomic adds (#151033)

As title

Pull Request resolved: #151033
Approved by: https://github.com/eellison, https://github.com/shunting314

1 parent 67d3053 commit fe96167

File tree

3 files changed: +30 −0 lines changed

test/inductor/test_cuda_repro.py
torch/_inductor/config.py
torch/_inductor/utils.py

test/inductor/test_cuda_repro.py

Lines changed: 26 additions & 0 deletions

@@ -1926,6 +1926,32 @@ def f(x, y):
 
         self.assertEqual(f(x_ref, y_ref), out)
 
+    @unittest.skipIf(
+        not config.is_fbcode(),
+        "bfloat16 atomic add is only supported in fbcode today #97016",
+    )
+    @skipCUDAIf(
+        not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
+    )
+    @config.patch({"bfloat16_atomic_adds_enabled": False})
+    def test_atomic_add_bfloat16_config(self):
+        def f(x, y):
+            return torch.index_select(x, 0, y)
+
+        x = torch.randn(
+            2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True
+        )
+        y = torch.ones(713268, dtype=torch.int64, device="cuda")
+        x_ref = x.clone().detach().requires_grad_(True)
+        y_ref = y.clone().detach()
+
+        out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))
+        fc = FileCheck()
+        fc.check_not("tl.atomic_add")
+        fc.run(bw_code)
+
+        self.assertEqual(f(x_ref, y_ref), out)
+
     @skipCUDAIf(
         not SM90OrLater, "uses bfloat16 atomic add instrs which requires SM >= 90"
     )
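
The new test patches the flag off and then asserts that the compiled backward kernel contains no tl.atomic_add call. The same check can be reproduced outside the test harness; a minimal sketch, with the import paths for run_fw_bw_and_get_code and FileCheck assumed to match what the test file uses (torch._inductor.utils and torch.testing, respectively):

# Sketch: confirm the backward kernel avoids bfloat16 atomic adds when the
# flag is disabled. Assumes a CUDA build of PyTorch; import paths are assumptions.
import torch
from torch._inductor import config
from torch._inductor.utils import run_fw_bw_and_get_code
from torch.testing import FileCheck


def f(x, y):
    return torch.index_select(x, 0, y)


x = torch.randn(2000, 384, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.ones(713268, dtype=torch.int64, device="cuda")

with config.patch({"bfloat16_atomic_adds_enabled": False}):
    # Run forward + backward under torch.compile and capture the generated code.
    out, (_, bw_code) = run_fw_bw_and_get_code(lambda: torch.compile(f)(x, y))

# With the flag off, Inductor should fall back instead of emitting tl.atomic_add.
FileCheck().check_not("tl.atomic_add").run(bw_code)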

torch/_inductor/config.py

Lines changed: 3 additions & 0 deletions

@@ -173,6 +173,9 @@ def prologue_fusion_enabled() -> bool:
 # Enable to allow using ftz variant of exponenet instruction in triton codegen.
 use_fast_math = os.environ.get("TORCHINDUCTOR_USE_FAST_MATH") == "1"
 
+# Enable bfloat16 atomic adds (fbcode only until upstreamed to triton)
+bfloat16_atomic_adds_enabled = True
+
 # How to organize memory under memory_planning=True:
 # - "none": do not try to pool storage, just reuse
 # - "intermediates": all non-outputs share storage, outputs each get unique storage
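
The flag defaults to True and, per the comment, bfloat16 atomic adds remain fbcode-only until the support is upstreamed to Triton. A minimal sketch of toggling it globally before compiling (the index_select workload is illustrative, mirroring the test above; assumes a CUDA build):

# Sketch: disable bfloat16 atomic adds process-wide so Inductor falls back
# for the scatter/index-add style backward of index_select on bfloat16.
import torch
import torch._inductor.config as inductor_config

inductor_config.bfloat16_atomic_adds_enabled = False  # default is True


def f(x, y):
    return torch.index_select(x, 0, y)


x = torch.randn(64, 8, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = torch.randint(0, 64, (256,), device="cuda")

out = torch.compile(f)(x, y)
out.sum().backward()  # compiled backward no longer uses bfloat16 atomic adds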

torch/_inductor/utils.py

Lines changed: 1 addition & 0 deletions

@@ -2268,6 +2268,7 @@ def needs_fallback_due_to_atomic_add_limitations(dtype: torch.dtype) -> bool:
         and dtype == torch.bfloat16
         and torch.cuda.is_available()
         and torch.cuda.get_device_capability() >= (9, 0)
+        and config.bfloat16_atomic_adds_enabled
     ):
         return False
     else:
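
The added condition gates the existing hardware check: bfloat16 atomic adds are only considered when the device reports compute capability 9.0 or newer (the same SM >= 90 requirement the test skips on) and the new flag is left enabled. A standalone sketch of that capability check, with a hypothetical helper name that is not part of the commit:

# Sketch: hardware bfloat16 atomic adds require compute capability >= 9.0
# (e.g. SM90/H100-class GPUs), matching the check in the hunk above.
# The helper name below is hypothetical.
import torch


def bf16_atomic_add_hw_supported() -> bool:
    return torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


if __name__ == "__main__":
    print("bf16 atomic add hardware support:", bf16_atomic_add_hw_supported())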

0 commit comments