@@ -3841,6 +3841,176 @@ def rand_nt(sequence_list, num_heads, head_dim):
 }
 )
 
+class TestSDPAXpuOnly(NNTestCase):
+    """Used to test XPU-only functionality of scaled_dot_product_attention.
+    Mostly migrated from TestSDPACudaOnly in test/test_transformers.py.
+
+    Note that SDPBackend.OVERRIDEABLE is not managed by sdpa_kernel, so the math
+    reference has to be called explicitly via torch.ops.aten._scaled_dot_product_attention_math.
+    """
+
+    @parametrize("type", ["dense"])
+    @parametrize("dropout", [0.0, 0.7])
+    @parametrize("dtype", [torch.float64, torch.float32, torch.bfloat16, torch.half])
+    @skipIfTorchDynamo()
+    def test_fused_sdp_choice_xpu(self, device, type: str, dropout: float, dtype: torch.dtype):
+        # Migrated from test_fused_sdp_choice_cpu
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype)
+        size = SdpaShape(2, 8, 128, 64)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
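+        # With dropout enabled or an unsupported dtype we expect the math fallback;
+        # otherwise the OVERRIDEABLE fused kernel should be chosen.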
+        if dropout > 0.0 or dtype not in [torch.float32, torch.bfloat16, torch.float16]:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.MATH.value
+        else:
+            assert torch._fused_sdp_choice(q, k, v, dropout_p=dropout) == SDPBackend.OVERRIDEABLE.value
+
+    def test_fused_attention_different_dk_dv(self, device):
+        dtype = torch.bfloat16
+        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
+        batch, num_heads, head_dim_k, head_dim_v = 32, 16, 128, 64
+        q_shape = SdpaShape(batch, num_heads, 1, head_dim_k)
+        k_shape = SdpaShape(batch, num_heads, 2, head_dim_k)
+        v_shape = SdpaShape(batch, num_heads, 2, head_dim_v)
+        query, key, value = make_tensor(q_shape), make_tensor(k_shape), make_tensor(v_shape)
+
+        # test that we do not dispatch to onednn for an unsupported case
+        actual = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.float(), key.float(), value.float(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
+
+    def test_onednn_attention_fail_d256(self, device):
+        # Test that onednn graph attention dispatching correctly bails out on d > 256
+        b, h = 1, 2
+        s_q, s_kv = 128, 128
+        d_qk, d_v = 512, 512
+
+        q = torch.randn(b, h, s_q, d_qk, device=device, dtype=torch.bfloat16)
+        k = torch.randn(b, h, s_kv, d_qk, device=device, dtype=torch.bfloat16)
+        v = torch.randn(b, h, s_kv, d_v, device=device, dtype=torch.bfloat16)
+
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            with self.assertRaisesRegex(RuntimeError, "No available kernel."):
+                _ = F.scaled_dot_product_attention(q, k, v)
+
+    @parametrize("type", ["dense"])
+    @parametrize("is_contiguous", [True, False])
+    def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
+        make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True)
+
+        batch_size, seq_len, num_heads, head_dim = 32, 64, 16, 64
+        shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
+
+        # Test Packed
+        qkv = make_tensor(shape)
+        query, key, value = qkv.chunk(3, dim=-1)
+
+        query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
+
+        if is_contiguous:
+            query = query.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+
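+        # Run the OVERRIDEABLE fused kernel and compare it against the explicit math reference.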
+        with sdpa_kernel(backends=[SDPBackend.OVERRIDEABLE]):
+            actual = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            query.contiguous(), key.contiguous(), value.contiguous(), attn_mask=None, dropout_p=0.0, is_causal=False)[0]
+
+        self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
+
+    @parametrize("fused_kernel", [SDPBackend.MATH, SDPBackend.OVERRIDEABLE])
+    @parametrize("dtype", [torch.half, torch.bfloat16, torch.float32])
+    @parametrize("batch_size,n_head,q_size,kv_size,head_dim", [
+        (2, 5, 9216, 9216, 64),
+        (2, 5, 9216, 77, 64),
+        (2, 10, 2304, 2304, 64),
+        (2, 10, 2304, 77, 64),
+        (2, 20, 576, 576, 64),
+        (2, 20, 576, 77, 64),
+        (2, 20, 144, 144, 64),
+        (2, 20, 144, 77, 64),
+        (1, 32, 1, 32, 128),
+        (4, 32, 1, 32, 128),
+        (1, 32, 32, 32, 128),
+        (4, 32, 32, 32, 128),
+        (1, 32, 2016, 2016, 128),
+        (4, 32, 2016, 2016, 128),
+    ])
+    @parametrize("mask_type", ["float", "causal"])
+    @parametrize("train", [False])
+    def test_scaled_dot_product_fused_attention_mask_vs_math(
+        self,
+        device,
+        fused_kernel,
+        dtype,
+        batch_size,
+        q_size,
+        kv_size,
+        n_head,
+        head_dim,
+        mask_type,
+        train,
+    ):
+        # Migrated from TestSDPACpuOnly
+        tol = Tolerances(1e-5, 5e-6)
+        if dtype is torch.bfloat16:
+            tol = Tolerances(5e-2, 5e-2)
+        if dtype is torch.float16:
+            tol = Tolerances(1e-2, 1e-2)
+        mask_shape = [batch_size, 1, 1, kv_size]
+        make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype, requires_grad=False)
+        q_shape = SdpaShape(batch_size, n_head, q_size, head_dim)
+        kv_shape = SdpaShape(batch_size, n_head, kv_size, head_dim)
+        q = make_tensor(q_shape)
+        k = make_tensor(kv_shape)
+        v = make_tensor(kv_shape)
+        q2, k2, v2 = q.clone(), k.clone(), v.clone()
+
+        if train:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+            q2.requires_grad_(True)
+            k2.requires_grad_(True)
+            v2.requires_grad_(True)
+
+        # (B, nh, T, hs)
+        q = q.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k = k.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v = v.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
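+        # Build the attention mask (or enable causal masking) according to mask_type.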
+        attn_mask = None
+        is_causal = False
+        if mask_type == "bool":
+            attn_mask = torch.randint(0, 2, size=mask_shape, dtype=torch.bool, device=device)
+        elif mask_type == "float":
+            attn_mask = torch.randn(mask_shape, dtype=dtype, device=device)
+        elif mask_type == "causal":
+            is_causal = True
+
+        q2, k2, v2 = q2.float(), k2.float(), v2.float()
+        q2 = q2.view(batch_size, q_size, n_head, head_dim).transpose(1, 2)
+        k2 = k2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        v2 = v2.view(batch_size, kv_size, n_head, head_dim).transpose(1, 2)
+        attn_mask2 = attn_mask.float() if attn_mask is not None else None
+
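+        # Call the aten op for the backend under test directly.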
+        if fused_kernel == SDPBackend.MATH:
+            actual = torch.ops.aten._scaled_dot_product_attention_math(
+                q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+        elif fused_kernel == SDPBackend.OVERRIDEABLE:
+            actual = torch.ops.aten._scaled_dot_product_fused_attention_overrideable(
+                q, k, v, attn_bias=attn_mask, dropout_p=0.0, is_causal=is_causal)[0]
+
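+        # Reference: math kernel on the float32 copies with a float mask.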
+        math_ref = torch.ops.aten._scaled_dot_product_attention_math(
+            q2, k2, v2, attn_mask=attn_mask2, dropout_p=0.0, is_causal=is_causal)[0]
+
+        self.assertEqual(actual.float(), math_ref, atol=tol.atol, rtol=tol.rtol)
+
 
 class TestAttnBias(NNTestCase):
 
@@ -4080,6 +4250,7 @@ def test_scaled_dot_product_fused_attention_overrideable_backward(self):
 instantiate_device_type_tests(TestSDPACudaOnly, globals(), only_for=("cuda"))
 instantiate_device_type_tests(TestSDPACpuOnly, globals(), only_for=("cpu"))
 instantiate_device_type_tests(TestAttnBias, globals(), only_for=device_types)
+instantiate_device_type_tests(TestSDPAXpuOnly, globals(), only_for="xpu", allow_xpu=True)
 
 if __name__ == '__main__':
     run_tests()