@@ -2658,6 +2658,7 @@ def test_strided_backwards(self):
            (1, 0, 2, 3),  # Reverse order
            (0, 2, 1, 3),  # Mixed order
            (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2),  # Non contiguous last dim
        ],
    )
    @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2707,12 +2708,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
    @common_utils.parametrize("mode", ["eager", "inductor"])
    @common_utils.parametrize(
        "permute_order",
-        [
-            (0, 1, 2, 3),
-            (1, 0, 2, 3),
-            (0, 2, 1, 3),
-            (2, 0, 1, 3),
-        ],
+        [(0, 1, 2, 3), (1, 0, 2, 3), (0, 2, 1, 3), (2, 0, 1, 3), (0, 1, 3, 2)],
    )
    @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
    def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shape):
@@ -2754,6 +2750,67 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
            )

+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having non contiguous last dimension."""
+        B, H, S, D = 4, 8, 128, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        def column_major_tensor():
+            tensor = torch.randn(
+                (B, H, S, D),
+                dtype=dtype,
+                device=device,
+            )
+            # Column major in last 2 dims
+            return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+        q = column_major_tensor()
+        k = column_major_tensor()
+        v = column_major_tensor()
+
+        if not self.test_inference_only:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+
+        self.assertNotEqual(q.stride()[-1], 1)
+        self.assertNotEqual(k.stride()[-1], 1)
+        self.assertNotEqual(v.stride()[-1], 1)
+
+        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+        golden_out = flex_attention(q_gold, k_gold, v_gold)
+        ref_out = flex_attention(q_ref, k_ref, v_ref)
+
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        compiled_out = flex_compiled(q, k, v)
+
+        self._check_out(golden_out, ref_out, compiled_out)
+
+        if not self.test_inference_only:
+            backward_grad = torch.randn_like(ref_out)
+
+            golden_out.backward(backward_grad.to(torch.float64))
+            ref_out.backward(backward_grad)
+            compiled_out.backward(backward_grad)
+
+            self._check_out_and_grad(
+                golden_out,
+                ref_out,
+                compiled_out,
+                q_gold,
+                q_ref,
+                q,
+                k_gold,
+                k_ref,
+                k,
+                v_gold,
+                v_ref,
+                v,
+            )
+
    @supported_platform
    @common_utils.parametrize("compile", [True, False])
    def test_fully_masked_out_rows_0_check(self, device, compile: bool):
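
Note (not part of the patch): the column_major_tensor helper in the new test works because transpose(-1, -2).contiguous().transpose(-1, -2) re-materializes the values column-major over the last two dimensions, keeping the logical (B, H, S, D) shape while the last-dimension stride becomes S instead of 1. A minimal standalone sketch of that layout, using arbitrary small shapes:

import torch

t = torch.randn(2, 4, 8, 16)                              # contiguous: stride (512, 128, 16, 1)
cm = t.transpose(-1, -2).contiguous().transpose(-1, -2)   # column-major over the last two dims
assert cm.shape == t.shape                                # logical shape unchanged: (2, 4, 8, 16)
assert cm.stride()[-1] != 1                               # last-dim stride is now 8 (== S), not 1
assert torch.equal(cm, t)                                 # values identical; only the layout differs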