@@ -1623,6 +1623,8 @@ def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
             dtype = torch.float16
             make_tensor = partial(torch.rand, device=device, dtype=dtype)
             size = SdpaShape(2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else SdpaShape(2, 2, 3, 257)
+            if TEST_WITH_ROCM:  # On ROCM, FA and EA share the backend GPU kernels
+                size = SdpaShape(2, 2, 3, 257)
             q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
             self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, None, 0.0, False))
@@ -1665,8 +1667,9 @@ def test_unaligned_tensors(self, device):
         make_tensor = partial(torch.rand, size, device=device, dtype=dtype)
         q, k, v = make_tensor(), make_tensor(), make_tensor()
         with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
-            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
-                q, k, v, None, 0.0, False))
+            ctxmgr = self.assertRaises(RuntimeError) if not TEST_WITH_ROCM else contextlib.nullcontext()
+            with ctxmgr:
+                torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
 
     @onlyCUDA
     @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
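
For reference, a minimal standalone sketch (not part of the PR) of the pattern used in the second hunk: `assertRaises(RuntimeError)` with no callable returns a context manager, and `contextlib.nullcontext()` is swapped in when no error is expected. The `ON_ROCM` flag and `_maybe_unsupported_op` helper below are hypothetical stand-ins for `TEST_WITH_ROCM` and the SDPA call.

import contextlib
import unittest

ON_ROCM = False  # assumption: stand-in for TEST_WITH_ROCM from torch.testing._internal


class ConditionalRaiseExample(unittest.TestCase):
    def _maybe_unsupported_op(self):
        # Stand-in for the fused SDPA call: fails only on the "unsupported" path.
        if not ON_ROCM:
            raise RuntimeError("unaligned head dim not supported")

    def test_maybe_raises(self):
        # Pick the expected-failure context manager at runtime.
        ctxmgr = self.assertRaises(RuntimeError) if not ON_ROCM else contextlib.nullcontext()
        with ctxmgr:
            self._maybe_unsupported_op()


if __name__ == "__main__":
    unittest.main()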