@@ -3838,6 +3838,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
}
)

+class TestSDPAXpuOnly(NNTestCase):
+    """Used to test XPU-only functionality of scaled_dot_product_attention.
+    Mostly migrated from TestSDPACudaOnly in test/test_transformers.py.
+
+    Note that SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel, so the math
+    ref has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
+    """
+
+    @parametrize("type", ["dense"])
+    @parametrize("dropout", [0.0, 0.7])
+    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
+    @skipIfTorchDynamo()
+    def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
+        # Migrated from test_fused_sdp_choice_cpu
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
+        size = SdpaShape(2, 8, 128, 64)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+        if dropout > 0.0 or dtype not in [torch.float32, torch.bfloat16, torch.float16]:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
+        else:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value
+
+    def test_fused_attention_different_dk_dv(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
+        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        # test that we do not dispatch to onednn for an unsupported case
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
+    def test_onednn_attention_fail_d256(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 256
+        b, h = 1, 2
+        s_q, s_kv = 128, 128
+        d_qk, d_v = 512, 512
+
+        q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
+        k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
+        v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            with self.assertRaisesRegex(RuntimeError, "No available kernel."):
+                _ = F.scaled_dot_product_attention(q, k, v)
+
+    @parametrize("type", ["dense"])
+    @parametrize("is_contiguous", [True, False])
+    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
+
+        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+
+        # Test Packed
+        qkv = make_tensor(shape)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
+        if is_contiguous:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
+
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
+    @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
+    @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
+        (2, 5, 9216, 9216, 64),
+        (2, 5, 9216, 77, 64),
+        (2, 10, 2304, 2304, 64),
+        (2, 10, 2304, 77, 64),
+        (2, 20, 576, 576, 64),
+        (2, 20, 576, 77, 64),
+        (2, 20, 144, 144, 64),
+        (2, 20, 144, 77, 64),
+        (1, 32, 1, 32, 128),
+        (4, 32, 1, 32, 128),
+        (1, 32, 32, 32, 128),
+        (4, 32, 32, 32, 128),
+        (1, 32, 2016, 2016, 128),
+        (4, 32, 2016, 2016, 128),
+    ])
+    @parametrize("mask_type", ["float", "causal"])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math(
+        self,
+        device,
+        fused_kernel,
+        dtype,
+        batch_size,
+        q_size,
+        kv_size,
+        n_head,
+        head_dim,
+        mask_type,
+        train,
+    ):
+        # Migrated from TestSDPACpuOnly
+        tol = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol = Tolerances(1e-2, 1e-2)
+        mask_shape = [batch_size, 1, 1, kv_size]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
+        q = make_tensor(q_shape)
+        k = make_tensor(kv_shape)
+        v = make_tensor(kv_shape)
+        q2, k2, v2 = q.clone(), k.clone(), v.clone()
+
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
+        # (B, nh, T, hs)
+        q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask = None
+        is_causal = False
+        if mask_type == "bool":
+            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
+        elif mask_type == "float":
+            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
+        elif mask_type == "causal":
+            is_causal = True
+
+        q2, k2, v2 = q2.float(), k2.float(), v2.float()
+        q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask2 = attn_mask.float() if attn_mask is not None else None
+
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
+
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
+


class TestAttnBias(NNTestCase):
@@ -4014,6 +4184,7 @@ def test_is_causal_and_mask_fails(self, device):
instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
+instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)

if __name__ == '__main__':
    run_tests()
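
For reference, a minimal standalone sketch of the comparison pattern the new tests rely on, assuming an XPU-enabled PyTorch build with an "xpu" device available: run scaled_dot_product_attention through the OVERRIDEABLE backend, then check it against the explicit math reference computed in float32.

# Illustrative sketch only (not part of the diff above); assumes torch was built
# with XPU support and an "xpu" device is visible.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

device = "xpu"
q, k, v = (torch.randn(2, 8, 128, 64, device=device, dtype=torch.half) for _ in range(3))

# Fused path selected via the OVERRIDEABLE backend.
with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
    actual = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# Explicit math reference in float32, as the class docstring describes.
math_ref = torch.ops.aten._scaled_dot_product_attention_math(
    q.float(), k.float(), v.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]

torch.testing.assert_close(actual.float(), math_ref, atol=2e-3, rtol=1e-2)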