@@ -141,12 +141,12 @@ def mt(shape, **kwargs):
141
141
skips = (
142
142
DecorateInfo (
143
143
unittest .expectedFailure ,
144
- ' TestCommon' ,
145
- ' test_non_standard_bool_values' ,
144
+ " TestCommon" ,
145
+ " test_non_standard_bool_values" ,
146
146
dtypes = [torch .bool ],
147
- device_type = ' mps'
147
+ device_type = " mps" ,
148
148
),
149
- )
149
+ ),
150
150
),
151
151
SpectralFuncInfo (
152
152
"fft.fft2" ,
@@ -180,10 +180,10 @@ def mt(shape, **kwargs):
180
180
),
181
181
DecorateInfo (
182
182
unittest .expectedFailure ,
183
- ' TestCommon' ,
184
- ' test_non_standard_bool_values' ,
183
+ " TestCommon" ,
184
+ " test_non_standard_bool_values" ,
185
185
dtypes = [torch .bool ],
186
- device_type = ' mps'
186
+ device_type = " mps" ,
187
187
),
188
188
),
189
189
),
@@ -211,12 +211,12 @@ def mt(shape, **kwargs):
211
211
skips = (
212
212
DecorateInfo (
213
213
unittest .expectedFailure ,
214
- ' TestCommon' ,
215
- ' test_non_standard_bool_values' ,
214
+ " TestCommon" ,
215
+ " test_non_standard_bool_values" ,
216
216
dtypes = [torch .bool ],
217
- device_type = ' mps'
217
+ device_type = " mps" ,
218
218
),
219
- )
219
+ ),
220
220
),
221
221
SpectralFuncInfo (
222
222
"fft.hfft" ,
@@ -250,16 +250,16 @@ def mt(shape, **kwargs):
250
250
# FIXME[MPS]: test_out_warning_fft_hfft_mps crashes with `Invalid KernelDAG`
251
251
DecorateInfo (
252
252
unittest .skip ("Skipped on MPS due to hard crash" ),
253
- ' TestCommon' ,
254
- ' test_out_warning' ,
255
- device_type = ' mps'
253
+ " TestCommon" ,
254
+ " test_out_warning" ,
255
+ device_type = " mps" ,
256
256
),
257
257
DecorateInfo (
258
258
unittest .expectedFailure ,
259
- ' TestCommon' ,
260
- ' test_non_standard_bool_values' ,
259
+ " TestCommon" ,
260
+ " test_non_standard_bool_values" ,
261
261
dtypes = [torch .bool ],
262
- device_type = ' mps'
262
+ device_type = " mps" ,
263
263
),
264
264
),
265
265
),
@@ -307,10 +307,10 @@ def mt(shape, **kwargs):
307
307
),
308
308
DecorateInfo (
309
309
unittest .expectedFailure ,
310
- ' TestCommon' ,
311
- ' test_non_standard_bool_values' ,
310
+ " TestCommon" ,
311
+ " test_non_standard_bool_values" ,
312
312
dtypes = [torch .bool ],
313
- device_type = ' mps'
313
+ device_type = " mps" ,
314
314
),
315
315
),
316
316
),
@@ -351,10 +351,10 @@ def mt(shape, **kwargs):
351
351
),
352
352
DecorateInfo (
353
353
unittest .expectedFailure ,
354
- ' TestCommon' ,
355
- ' test_non_standard_bool_values' ,
354
+ " TestCommon" ,
355
+ " test_non_standard_bool_values" ,
356
356
dtypes = [torch .bool ],
357
- device_type = ' mps'
357
+ device_type = " mps" ,
358
358
),
359
359
),
360
360
),
@@ -379,10 +379,10 @@ def mt(shape, **kwargs):
379
379
skips = (
380
380
DecorateInfo (
381
381
unittest .expectedFailure ,
382
- ' TestCommon' ,
383
- ' test_non_standard_bool_values' ,
382
+ " TestCommon" ,
383
+ " test_non_standard_bool_values" ,
384
384
dtypes = [torch .bool ],
385
- device_type = ' mps'
385
+ device_type = " mps" ,
386
386
),
387
387
),
388
388
check_batched_gradgrad = False ,
@@ -412,10 +412,10 @@ def mt(shape, **kwargs):
412
412
skips = (
413
413
DecorateInfo (
414
414
unittest .expectedFailure ,
415
- ' TestCommon' ,
416
- ' test_non_standard_bool_values' ,
415
+ " TestCommon" ,
416
+ " test_non_standard_bool_values" ,
417
417
dtypes = [torch .bool ],
418
- device_type = ' mps'
418
+ device_type = " mps" ,
419
419
),
420
420
),
421
421
),
@@ -444,12 +444,12 @@ def mt(shape, **kwargs):
444
444
skips = (
445
445
DecorateInfo (
446
446
unittest .expectedFailure ,
447
- ' TestCommon' ,
448
- ' test_non_standard_bool_values' ,
447
+ " TestCommon" ,
448
+ " test_non_standard_bool_values" ,
449
449
dtypes = [torch .bool ],
450
- device_type = ' mps'
450
+ device_type = " mps" ,
451
451
),
452
- )
452
+ ),
453
453
),
454
454
SpectralFuncInfo (
455
455
"fft.ifft" ,
@@ -474,12 +474,12 @@ def mt(shape, **kwargs):
474
474
skips = (
475
475
DecorateInfo (
476
476
unittest .expectedFailure ,
477
- ' TestCommon' ,
478
- ' test_non_standard_bool_values' ,
477
+ " TestCommon" ,
478
+ " test_non_standard_bool_values" ,
479
479
dtypes = [torch .bool ],
480
- device_type = ' mps'
480
+ device_type = " mps" ,
481
481
),
482
- )
482
+ ),
483
483
),
484
484
SpectralFuncInfo (
485
485
"fft.ifft2" ,
@@ -511,12 +511,12 @@ def mt(shape, **kwargs):
511
511
skips = (
512
512
DecorateInfo (
513
513
unittest .expectedFailure ,
514
- ' TestCommon' ,
515
- ' test_non_standard_bool_values' ,
514
+ " TestCommon" ,
515
+ " test_non_standard_bool_values" ,
516
516
dtypes = [torch .bool ],
517
- device_type = ' mps'
517
+ device_type = " mps" ,
518
518
),
519
- )
519
+ ),
520
520
),
521
521
SpectralFuncInfo (
522
522
"fft.ifftn" ,
@@ -548,12 +548,12 @@ def mt(shape, **kwargs):
548
548
skips = (
549
549
DecorateInfo (
550
550
unittest .expectedFailure ,
551
- ' TestCommon' ,
552
- ' test_non_standard_bool_values' ,
551
+ " TestCommon" ,
552
+ " test_non_standard_bool_values" ,
553
553
dtypes = [torch .bool ],
554
- device_type = ' mps'
554
+ device_type = " mps" ,
555
555
),
556
- )
556
+ ),
557
557
),
558
558
SpectralFuncInfo (
559
559
"fft.ihfft" ,
@@ -573,8 +573,13 @@ def mt(shape, **kwargs):
573
573
torch .bool , * (() if (not SM53OrLater ) else (torch .half ,))
574
574
),
575
575
skips = (
576
- DecorateInfo (unittest .expectedFailure , 'TestCommon' , 'test_non_standard_bool_values' ,
577
- dtypes = [torch .bool ], device_type = 'mps' ),
576
+ DecorateInfo (
577
+ unittest .expectedFailure ,
578
+ "TestCommon" ,
579
+ "test_non_standard_bool_values" ,
580
+ dtypes = [torch .bool ],
581
+ device_type = "mps" ,
582
+ ),
578
583
),
579
584
check_batched_grad = False ,
580
585
),
@@ -610,9 +615,14 @@ def mt(shape, **kwargs):
610
615
DecorateInfo (unittest .expectedFailure , "TestCommon" , "test_out_warnings" ),
611
616
),
612
617
skips = (
613
- DecorateInfo (unittest .expectedFailure , 'TestCommon' , 'test_non_standard_bool_values' ,
614
- dtypes = [torch .bool ], device_type = 'mps' ),
615
- )
618
+ DecorateInfo (
619
+ unittest .expectedFailure ,
620
+ "TestCommon" ,
621
+ "test_non_standard_bool_values" ,
622
+ dtypes = [torch .bool ],
623
+ device_type = "mps" ,
624
+ ),
625
+ ),
616
626
),
617
627
SpectralFuncInfo (
618
628
"fft.ihfftn" ,
@@ -645,9 +655,14 @@ def mt(shape, **kwargs):
645
655
),
646
656
],
647
657
skips = (
648
- DecorateInfo (unittest .expectedFailure , 'TestCommon' , 'test_non_standard_bool_values' ,
649
- dtypes = [torch .bool ], device_type = 'mps' ),
650
- )
658
+ DecorateInfo (
659
+ unittest .expectedFailure ,
660
+ "TestCommon" ,
661
+ "test_non_standard_bool_values" ,
662
+ dtypes = [torch .bool ],
663
+ device_type = "mps" ,
664
+ ),
665
+ ),
651
666
),
652
667
SpectralFuncInfo (
653
668
"fft.irfft" ,
@@ -674,18 +689,18 @@ def mt(shape, **kwargs):
674
689
# FIXME[MPS]: test_out_warning_fft_irfft_mps crashes with `Invalid KernelDAG`
675
690
DecorateInfo (
676
691
unittest .skip ("Skipped on MPS due to hard crash" ),
677
- ' TestCommon' ,
678
- ' test_out_warning' ,
679
- device_type = ' mps' ,
692
+ " TestCommon" ,
693
+ " test_out_warning" ,
694
+ device_type = " mps" ,
680
695
),
681
696
DecorateInfo (
682
697
unittest .expectedFailure ,
683
- ' TestCommon' ,
684
- ' test_non_standard_bool_values' ,
698
+ " TestCommon" ,
699
+ " test_non_standard_bool_values" ,
685
700
dtypes = [torch .bool ],
686
- device_type = ' mps'
701
+ device_type = " mps" ,
687
702
),
688
- )
703
+ ),
689
704
),
690
705
SpectralFuncInfo (
691
706
"fft.irfft2" ,
@@ -719,18 +734,18 @@ def mt(shape, **kwargs):
719
734
# FIXME[MPS]: test_out_warning_fft_irfft2_mps crashes with `Invalid KernelDAG`
720
735
DecorateInfo (
721
736
unittest .skip ("Skipped on MPS due to hard crash" ),
722
- ' TestCommon' ,
723
- ' test_out_warning' ,
724
- device_type = ' mps' ,
737
+ " TestCommon" ,
738
+ " test_out_warning" ,
739
+ device_type = " mps" ,
725
740
),
726
741
DecorateInfo (
727
742
unittest .expectedFailure ,
728
- ' TestCommon' ,
729
- ' test_non_standard_bool_values' ,
743
+ " TestCommon" ,
744
+ " test_non_standard_bool_values" ,
730
745
dtypes = [torch .bool ],
731
- device_type = ' mps'
746
+ device_type = " mps" ,
732
747
),
733
- )
748
+ ),
734
749
),
735
750
SpectralFuncInfo (
736
751
"fft.irfftn" ,
@@ -761,9 +776,14 @@ def mt(shape, **kwargs):
761
776
)
762
777
],
763
778
skips = (
764
- DecorateInfo (unittest .expectedFailure , 'TestCommon' , 'test_non_standard_bool_values' ,
765
- dtypes = [torch .bool ], device_type = 'mps' ),
766
- )
779
+ DecorateInfo (
780
+ unittest .expectedFailure ,
781
+ "TestCommon" ,
782
+ "test_non_standard_bool_values" ,
783
+ dtypes = [torch .bool ],
784
+ device_type = "mps" ,
785
+ ),
786
+ ),
767
787
),
768
788
OpInfo (
769
789
"fft.fftshift" ,
0 commit comments