@@ -105,7 +105,6 @@ def _check_common(
                     ):
                         self.assertEqual(arg1.grad, arg2.grad, atol=atol, rtol=rtol)

-    @skipIfRocm
     def _test_sdpa_rewriter_1(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -132,7 +131,6 @@ def dot_prod_attention(
             rtol=rtol,
         )

-    @skipIfRocm
     @torch._inductor.config.patch("freezing", True)
     def _test_sdpa_rewriter_1_freezing(self):
         def dot_prod_attention(
@@ -161,7 +159,6 @@ def dot_prod_attention(
             check_train=False,
         )

-    @skipIfRocm  # https://github.com/pytorch/pytorch/issues/146848
     def _test_insignificant_strides(self):
         f32 = torch.float32
@@ -265,7 +262,6 @@ def dot_prod_attention(
         _, (source_code,) = run_and_get_code(dot_prod_attention, *args)
         self.assertNotIn("aten._scaled_dot_product_efficient_attention", source_code)

-    @skipIfRocm
     def _test_sdpa_rewriter_2(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -280,7 +276,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention)
         self._check_common(checkpoint_wrapper(dot_prod_attention))

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_3(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training: bool
@@ -297,7 +292,6 @@ def dot_prod_attention(
             checkpoint_wrapper(dot_prod_attention), contains=False, has_dropout=True
         )

-    @skipIfRocm  # AssertionError: expected size 4==4, stride 32==64 at dim=0
     def _test_sdpa_rewriter_4(self):
         def dot_prod_attention(
             query: torch.Tensor,
@@ -347,7 +341,6 @@ def sfdp_pattern_5_v2(query, key, value):
         self._check_common(sfdp_pattern_5_v2, contains=False)
         self._check_common(checkpoint_wrapper(sfdp_pattern_5_v2), contains=False)

-    @skipIfRocm
     def _test_sdpa_rewriter_6(self):
         def sfdp_pattern_6(query, key, value, training):
             attn_mask = torch.ones(
@@ -571,7 +564,6 @@ def forward(self, query, key, value, attn_mask) -> torch.Tensor:
             model, args1=args, contains=False, atol=1e-4, has_fuse_pattern=False
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_11(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -612,7 +604,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention, contains=False, has_dropout=True)

-    @skipIfRocm
     def _test_sdpa_prev_13(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -629,7 +620,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -645,7 +635,6 @@ def dot_prod_attention(
         self._check_common(dot_prod_attention, check_train=False)
         self._check_common(checkpoint_wrapper(dot_prod_attention), check_train=False)

-    @skipIfRocm
     def _test_sdpa_prev_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -695,7 +684,6 @@ def dot_prod_attention(
             rtol=1e-2,
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_14(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -718,7 +706,6 @@ def dot_prod_attention(

         self._check_common(dot_prod_attention)

-    @skipIfRocm
     def _test_sdpa_rewriter_15(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
@@ -811,7 +798,6 @@ def dot_prod_attention(
             dot_prod_attention, args1=args, contains=False, has_dropout=True
         )

-    @skipIfRocm
     def _test_sdpa_rewriter_17(self):
         def dot_prod_attention(
             query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, training