[ROCm] Prevent accidental enablement of efficient attention. (#134531) · ROCm/pytorch@b4adece

Commit b4adece

[ROCm] Prevent accidental enablement of efficient attention. (pytorch#134531)
[ROCm] Prevent accidental enablement of efficient attention. (pytorch#133331)

Currently, efficient attention and flash attention share the same set of GPU kernels on ROCm and therefore have common limitations on head sizes.

Fixes pytorch#132004

Pull Request resolved: pytorch#133331
Approved by: https://github.com/malfet, https://github.com/jithunnair-amd
(cherry picked from commit 46ecc67)
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
1 parent 3b54c45 commit b4adece
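To make the shared head-size limit concrete, here is a minimal, hedged Python sketch of the user-visible behavior this commit targets. It assumes a ROCm build (where the "cuda" device string maps to HIP devices) and uses a head dimension of 257, past the limit that the flash and efficient-attention kernels share on ROCm; the shapes mirror the updated test below, and the exact error message is not guaranteed.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# head_dim = 257 exceeds the head-size limit that flash and efficient
# attention share on ROCm (shapes chosen to mirror the test change below).
q, k, v = (torch.rand(2, 2, 3, 257, device="cuda", dtype=torch.float16)
           for _ in range(3))

with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        F.scaled_dot_product_attention(q, k, v, None, 0.0, False)
    except RuntimeError as err:
        # With this fix, a ROCm build rejects the call here instead of
        # dispatching to a kernel that cannot handle the head size.
        print("efficient attention not usable:", err)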

2 files changed: +11 -3 lines changed

aten/src/ATen/native/transformers/cuda/sdp_utils.cpp

Lines changed: 6 additions & 1 deletion
@@ -620,7 +620,12 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
       check_all_tensors_on_device,
       check_mem_efficient_hardware_support,
       check_tensor_shapes,
-      check_head_dim_size_mem_efficient);
+#ifdef USE_ROCM
+      check_head_dim_size_flash
+#else
+      check_head_dim_size_mem_efficient
+#endif
+      );
   for (auto& constraint : general_constraints) {
     if (!constraint(params, debug)) {
       return false;

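The surrounding C++ function composes a list of constraint checks and returns false on the first failure; the diff only swaps which head-dim check appears in that list when built for ROCm. Below is a loose Python analogue of that constraint-list pattern. The check names come from the diff, but their bodies and the numeric limits are illustrative stand-ins, not the real sdp_utils.cpp logic.

import torch

USE_ROCM = True  # stand-in for the USE_ROCM compile-time flag

def check_tensor_shapes(params, debug=False):
    # The real check validates query/key/value dimensionality; simplified here.
    return params["query"].dim() == 4

def check_head_dim_size_flash(params, debug=False):
    # Flash-style limit that ROCm's shared kernels also obey (illustrative value).
    return params["query"].size(-1) <= 256

def check_head_dim_size_mem_efficient(params, debug=False):
    # Looser CUDA-only efficient-attention limit (illustrative value).
    return params["query"].size(-1) <= 512

def can_use_mem_efficient_attention(params, debug=False):
    general_constraints = [
        check_tensor_shapes,
        check_head_dim_size_flash if USE_ROCM else check_head_dim_size_mem_efficient,
    ]
    # Mirrors the C++ loop: fail as soon as one constraint is unsatisfied.
    return all(constraint(params, debug) for constraint in general_constraints)

params = {"query": torch.rand(2, 2, 3, 257, dtype=torch.float16)}
print(can_use_mem_efficient_attention(params))  # False under the ROCm-style limit
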
test/test_transformers.py

Lines changed: 5 additions & 2 deletions
@@ -1457,6 +1457,8 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
         dtype = torch.float16
         make_tensor = partial(torch.rand, device=device, dtype=dtype)
         size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
+        if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
+            size = SdpaShape(2, 2, 3, 257)
         q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
         self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
             q, k, v, None, 0.0, False))
@@ -1499,8 +1501,9 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, None, 0.0, False))
+            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            with ctxmgr:
+                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")

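The second hunk leans on a small testing pattern worth noting: choosing between assertRaises and contextlib.nullcontext() at runtime lets one with-block cover both the platform where the call must raise and the platform where it must succeed. A standalone sketch of that pattern follows; the EXPECT_FAILURE flag and the trivial body are hypothetical stand-ins for the real platform check and SDPA call.

import contextlib
import unittest

EXPECT_FAILURE = False  # hypothetical flag, playing the role of `not TEST_WITH_ROCM`

class ConditionalRaiseExample(unittest.TestCase):
    def test_maybe_raises(self):
        # assertRaises returns a context manager; nullcontext() is a no-op one,
        # so the same with-body works whether or not an error is expected.
        ctx = self.assertRaises(RuntimeError) if EXPECT_FAILURE else contextlib.nullcontext()
        with ctx:
            result = 1 + 1  # stand-in for the call under test
            self.assertEqual(result, 2)

if __name__ == "__main__":
    unittest.main()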