@@ -2658,6 +2658,7 @@ def test_strided_backwards(self):
             (1, 0, 2, 3),  # Reverse order
             (0, 2, 1, 3),  # Mixed order
             (2, 0, 1, 3),  # Another mixed order
+            (0, 1, 3, 2)  # Non contiguous last dim
         ],
     )
     @common_utils.parametrize("shape", [(2, 1, 128, 16), (4, 2, 64, 16)])
@@ -2712,6 +2713,7 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
             (1, 0, 2, 3),
             (0, 2, 1, 3),
             (2, 0, 1, 3),
+            (0, 1, 3, 2)
         ],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
@@ -2754,6 +2756,75 @@ def test_flex_attention_backward_stride_ordering(self, mode, permute_order, shap
                 f"Mode: {mode}, Stride order mismatch for {name}: grad {input_stride_order}, input {orig_stride_order}.",
             )

+    @supported_platform
+    def test_non_contiguous_last_dim(self, device):
+        """Test flex_attention with tensors having a non-contiguous last dimension."""
+        B, H, S, D = 4, 8, 128, 64
+        dtype = torch.float16 if device == "cuda" else torch.float32
+
+        def create_non_unit_stride_tensor():
+            tensor = torch.randn(
+                (B, H, S, D),
+                dtype=dtype,
+                device=device,
+            )
+            # Column-major in the last 2 dims
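+            # (transpose -> contiguous -> transpose keeps shape and values the same,
+            # but leaves the last dim with stride S and the seq dim with stride 1)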
+            return tensor.transpose(-1, -2).contiguous().transpose(-1, -2)
+
+        # Create tensors with non-unit stride
+        q = create_non_unit_stride_tensor()
+        k = create_non_unit_stride_tensor()
+        v = create_non_unit_stride_tensor()
+
+        if not self.test_inference_only:
+            q.requires_grad_(True)
+            k.requires_grad_(True)
+            v.requires_grad_(True)
+
+        # Verify last dimension has non-unit stride
+        self.assertNotEqual(q.stride()[-1], 1)
+        self.assertNotEqual(k.stride()[-1], 1)
+        self.assertNotEqual(v.stride()[-1], 1)
+
+        # Create clones for different computation paths
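+        # (the float64 "gold" clones serve as a high-precision reference for the checks below)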
+        q_ref, k_ref, v_ref = query_key_value_clones(q, k, v)
+        q_gold, k_gold, v_gold = query_key_value_clones(q, k, v, torch.float64)
+
+        # Run with different precisions and compilation
+        golden_out = flex_attention(q_gold, k_gold, v_gold)
+        ref_out = flex_attention(q_ref, k_ref, v_ref)
+
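+        # (fullgraph=True makes torch.compile raise on any graph break, so the whole
+        # flex_attention call must be captured as a single graph)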
+        flex_compiled = torch.compile(flex_attention, fullgraph=True)
+        compiled_out = flex_compiled(q, k, v)
+
+        # Check forward pass correctness
+        self._check_out(golden_out, ref_out, compiled_out)
+
+        if not self.test_inference_only:
+            # For backward pass testing
+            backward_grad = torch.randn_like(ref_out)
+
+            golden_out.backward(backward_grad.to(torch.float64))
+            ref_out.backward(backward_grad)
+            compiled_out.backward(backward_grad)
+
+            # Check backward pass correctness
+            self._check_out_and_grad(
+                golden_out,
+                ref_out,
+                compiled_out,
+                q_gold,
+                q_ref,
+                q,
+                k_gold,
+                k_ref,
+                k,
+                v_gold,
+                v_ref,
+                v,
+            )
+
     @supported_platform
     @common_utils.parametrize("compile", [True, False])
     def test_fully_masked_out_rows_0_check(self, device, compile: bool):