@@ -2539,9 +2539,7 @@ def mask(b, h, q, kv):
         k.grad = None
         v.grad = None
 
-        block_mask2 = create_block_mask(
-            mask, None, None, 2048, 2048, device=device
-        )
+        block_mask2 = create_block_mask(mask, None, None, 2048, 2048, device=device)
         # Reuse the 1st version with q/k/v(seqlen=2048) and block_mask(seqlen=2048)
         out2 = torch.compile(flex_attention, dynamic=True)(
             q, k, v, block_mask=block_mask2
@@ -2732,7 +2730,9 @@ def test_flex_attention_stride_ordering(self, device, mode, permute_order, shape
         ],
     )
     @common_utils.parametrize("shape", [(2, 5, 128, 16), (4, 2, 64, 16)])
-    def test_flex_attention_backward_stride_ordering(self, device, mode, permute_order, shape):
+    def test_flex_attention_backward_stride_ordering(
+        self, device, mode, permute_order, shape
+    ):
         from torch._inductor.ir import get_stride_order
 
         dtype = torch.float32
@@ -2865,15 +2865,11 @@ def causal_mask(b, h, q_idx, kv_idx):
         score_mod_sparse_flex = functools.partial(
             flex_attention,
             score_mod=causal,
-            block_mask=create_block_mask(
-                causal_mask, 1, 1, 2048, 2048, device=device
-            ),
+            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device),
         )
         mask_mod_sparse_flex = functools.partial(
             flex_attention,
-            block_mask=create_block_mask(
-                causal_mask, 1, 1, 2048, 2048, device=device
-            ),
+            block_mask=create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device),
         )
         for attention_call in [
             no_sparse_flex,
@@ -2892,9 +2888,7 @@ def causal_mask(b, h, q_idx, kv_idx):
                 )
                 for _ in range(3)
             ]
-            gradOut = torch.randn(
-                2, 2, 2048, 64, device=device, dtype=torch.float16
-            )
+            gradOut = torch.randn(2, 2, 2048, 64, device=device, dtype=torch.float16)
             out_ref = torch.nn.functional.scaled_dot_product_attention(
                 *inputs, is_causal=True
             )
@@ -2949,9 +2943,7 @@ def mod(b, h, q, kv):
     @unittest.skipIf(SKIP_UT_ON_CPU, "Skip on CPU as not supported")
     def test_head_bias_req_grad(self, device):
         B, H, S, D = 1, 4, 256, 64
-        bias = torch.randn(
-            H, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(H, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3017,9 +3009,7 @@ def rel_pos_1d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias:
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3048,9 +3038,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias + index multiple
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3079,9 +3067,7 @@ def rel_pos_2d(score, b, h, q_idx, kv_idx):
 
         # 2-dimensional bias + transposed:
         B, H, S, D = 1, 1, 256, 64
-        bias = torch.randn(
-            S, S, device=device, dtype=torch.float16, requires_grad=True
-        )
+        bias = torch.randn(S, S, device=device, dtype=torch.float16, requires_grad=True)
 
         bias_flex = bias.detach().clone().requires_grad_(True)
 
@@ -3868,7 +3854,7 @@ def forward(self, arg0_1: "i32[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
             joint_graph,
             expected_joint_graph,
         )
-
+
     @supported_platform
     @unittest.skipIf(TEST_ON_CUDA, "Testing CPU error message")
     def test_cpu_error_message_return_lse(self, device):
@@ -3992,9 +3978,7 @@ def test_block_mask_attributes(self, device):
         def causal_mask(b, h, q, kv):
             return (q + (offset[b] * 128)) >= kv
 
-        block_mask = create_block_mask(
-            causal_mask, 4, 2, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 4, 2, 2048, 2048, device=device)
         self.assertEqual(block_mask.shape, (4, 2, 2048, 2048))
         self.assertEqual(block_mask[0].shape, (2, 2048, 2048))
         self.assertEqual(block_mask[0, 0].shape, (2048, 2048))
@@ -4005,9 +3989,7 @@ def causal_mask(b, h, q, kv):
         self.assertEqual(block_mask.sparsity(), block_mask[1].sparsity())
 
         offset = torch.arange(8, device=device)
-        block_mask = create_block_mask(
-            causal_mask, 8, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 8, 1, 2048, 2048, device=device)
         self.assertEqual(block_mask.sparsity(), 29.1015625)
         self.assertTrue(block_mask.sparsity() < block_mask[0].sparsity())
         self.assertTrue(block_mask[0].sparsity() > block_mask[1].sparsity())
@@ -4163,9 +4145,7 @@ def test_block_mask_viz(self, device):
         def causal_mask(b, h, q, kv):
             return q >= kv
 
-        block_mask = create_block_mask(
-            causal_mask, 1, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
 
         def replace_non_printable(s):
             def replace(c):
@@ -4368,9 +4348,7 @@ def test_no_q_info(self, device, compile: bool):
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx
 
-        block_mask = create_block_mask(
-            causal_mask, 1, 1, 2048, 2048, device=device
-        )
+        block_mask = create_block_mask(causal_mask, 1, 1, 2048, 2048, device=device)
         # manually set q_num_blocks and q_indices to None
         block_mask.q_num_blocks = None
         block_mask.q_indices = None
@@ -4494,9 +4472,7 @@ def test_eager_tracing_correctness(self, device):
         seq_len = 256
         batch_size = 1
 
-        make_tensor = functools.partial(
-            torch.randn, device=device, dtype=torch.float16
-        )
+        make_tensor = functools.partial(torch.randn, device=device, dtype=torch.float16)
         q = make_tensor(*(batch_size, q_heads, seq_len, qk_dims))
         k = make_tensor(*(batch_size, kv_heads, seq_len, qk_dims))
         v = make_tensor(*(batch_size, kv_heads, seq_len, v_dims))
@@ -4556,16 +4532,12 @@ def create_inputs(S):
             )
             return q, k, v
 
-        block_mask = create_block_mask(
-            mask_mod, None, None, 1024, 1024, device=device
-        )
+        block_mask = create_block_mask(mask_mod, None, None, 1024, 1024, device=device)
         flex_attention_call(*create_inputs(1024), block_mask=block_mask)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(2048), block_mask=block_mask)
 
-        block_mask = create_block_mask(
-            mask_mod, None, None, 1023, 1023, device=device
-        )
+        block_mask = create_block_mask(mask_mod, None, None, 1023, 1023, device=device)
         with self.assertRaisesRegex(ValueError, "block_mask was created for"):
             flex_attention_call(*create_inputs(1024), block_mask=block_mask)
 
@@ -5634,10 +5606,19 @@ def sliding_window(b, h, q_idx, kv_idx, val):
         opt_fn(sliding_window2, None, None, 1024, 1024)
 
 
-instantiate_device_type_tests(TestFlexAttention, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestPagedAttention, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestBlockMask, globals(), only_for=test_device, allow_xpu=True)
-instantiate_device_type_tests(TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True)
+instantiate_device_type_tests(
+    TestFlexAttention, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestPagedAttention, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestBlockMask, globals(), only_for=test_device, allow_xpu=True
+)
+instantiate_device_type_tests(
+    TestLearnableBiases, globals(), only_for=test_device, allow_xpu=True
+)
+
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests
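For context, a minimal, self-contained sketch (not part of the diff) of the create_block_mask / flex_attention pattern the reformatted tests exercise; it assumes a CUDA-capable device and reuses the (2, 2, 2048, 64) float16 shapes seen in the hunks above.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def causal_mask(b, h, q_idx, kv_idx):
    # Allow each query position to attend only to keys at or before it.
    return q_idx >= kv_idx


device = "cuda"  # assumption: a GPU is available for float16 flex_attention
q, k, v = (
    torch.randn(2, 2, 2048, 64, device=device, dtype=torch.float16)
    for _ in range(3)
)
# The BlockMask is built for the (Q_LEN, KV_LEN) it will be used with,
# which is what the "block_mask was created for" checks above rely on.
block_mask = create_block_mask(causal_mask, None, None, 2048, 2048, device=device)
out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)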