@@ -693,5 +693,83 @@ def test_event_synchronization_callback(self):
693
693
self .mock .assert_called_once_with (event ._as_parameter_ .value )
694
694
695
695
696
+ class TestSkipIfXpuDecroator (TestCase ):
697
+ def test_skip_if_xpu_for_class (self ):
698
+ script = """
699
+ from torch.testing._internal.common_utils import run_tests, skipIfXpu, TestCase
700
+
701
+ @skipIfXpu
702
+ class TestFoo(TestCase):
703
+ def test_case_1(self):
704
+ print("test_case_1")
705
+ self.assertTrue(True)
706
+
707
+ def test_case_2(self):
708
+ print("test_case_2")
709
+ self.assertTrue(True)
710
+
711
+ def test_case_3(self):
712
+ print("test_case_3")
713
+ self.assertTrue(True)
714
+
715
+ if __name__ == "__main__":
716
+ run_tests()
717
+ """
718
+ rc = (
719
+ subprocess .check_output (
720
+ [sys .executable , "-c" , script ], stderr = subprocess .STDOUT
721
+ )
722
+ .decode ("ascii" )
723
+ .strip ()
724
+ )
725
+ expected_ref = torch .xpu .is_available ()
726
+ expected_res = "test_case_1" not in rc
727
+ self .assertEqual (expected_ref , expected_res )
728
+ expected_res = "test_case_2" not in rc
729
+ self .assertEqual (expected_ref , expected_res )
730
+ expected_res = "test_case_3" not in rc
731
+ self .assertEqual (expected_ref , expected_res )
732
+
733
+ def test_skip_if_xpu_for_case (self ):
734
+ script = """
735
+ from torch.testing._internal.common_utils import run_tests, skipIfXpu, TestCase
736
+
737
+ class TestFoo(TestCase):
738
+ def test_case_1(self):
739
+ print("test_case_1")
740
+ self.assertTrue(True)
741
+
742
+ @skipIfXpu
743
+ def test_case_2(self):
744
+ print("test_case_2")
745
+ self.assertTrue(True)
746
+
747
+ def test_case_3(self):
748
+ print("test_case_3")
749
+ self.assertTrue(True)
750
+
751
+ if __name__ == "__main__":
752
+ run_tests()
753
+ """
754
+ rc = (
755
+ subprocess .check_output (
756
+ [sys .executable , "-c" , script ], stderr = subprocess .STDOUT
757
+ )
758
+ .decode ("ascii" )
759
+ .strip ()
760
+ )
761
+ expected_res = "test_case_1" in rc
762
+ self .assertTrue (expected_res )
763
+
764
+ expected_ref = torch .xpu .is_available ()
765
+ expected_res = "test_case_2" not in rc
766
+ self .assertEqual (expected_ref , expected_res )
767
+
768
+ expected_res = "test_case_3" in rc
769
+ self .assertTrue (expected_res )
770
+
771
+
772
+ instantiate_device_type_tests (TestSkipIfXpuDecroator , globals (), allow_xpu = True )
773
+
696
774
if __name__ == "__main__" :
697
775
run_tests ()
0 commit comments