8000 Enable skipIfXpu support class-level skip · pytorch/pytorch@ecdcaf1 · GitHub
[go: up one dir, main page]

Skip to content

Commit ecdcaf1

Browse files
committed
Enable skipIfXpu support class-level skip
ghstack-source-id: 51da282 Pull Request resolved: #151420
1 parent a72642b commit ecdcaf1

File tree

2 files changed

+87
-2
lines changed

2 files changed

+87
-2
lines changed

test/test_xpu.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -693,5 +693,83 @@ def test_event_synchronization_callback(self):
693693
self.mock.assert_called_once_with(event._as_parameter_.value)
694694

695695

696+
class TestSkipIfXpuDecroator(TestCase):
697+
def test_skip_if_xpu_for_class(self):
698+
script = """
699+
from torch.testing._internal.common_utils import run_tests, skipIfXpu, TestCase
700+
701+
@skipIfXpu
702+
class TestFoo(TestCase):
703+
def test_case_1(self):
704+
print("test_case_1")
705+
self.assertTrue(True)
706+
707+
def test_case_2(self):
708+
print("test_case_2")
709+
self.assertTrue(True)
710+
711+
def test_case_3(self):
712+
print("test_case_3")
713+
self.assertTrue(True)
714+
715+
if __name__ == "__main__":
716+
run_tests()
717+
"""
718+
rc = (
719+
subprocess.check_output(
720+
[sys.executable, "-c", script], stderr=subprocess.STDOUT
721+
)
722+
.decode("ascii")
723+
.strip()
724+
)
725+
expected_ref = torch.xpu.is_available()
726+
expected_res = "test_case_1" not in rc
727+
self.assertEqual(expected_ref, expected_res)
728+
expected_res = "test_case_2" not in rc
729+
self.assertEqual(expected_ref, expected_res)
730+
expected_res = "test_case_3" not in rc
731+
self.assertEqual(expected_ref, expected_res)
732+
733+
def test_skip_if_xpu_for_case(self):
734+
script = """
735+
from torch.testing._internal.common_utils import run_tests, skipIfXpu, TestCase
736+
737+
class TestFoo(TestCase):
738+
def test_case_1(self):
739+
print("test_case_1")
740+
self.assertTrue(True)
741+
742+
@skipIfXpu
743+
def test_case_2(self):
744+
print("test_case_2")
745+
self.assertTrue(True)
746+
747+
def test_case_3(self):
748+
print("test_case_3")
749+
self.assertTrue(True)
750+
751+
if __name__ == "__main__":
752+
run_tests()
753+
"""
754+
rc = (
755+
subprocess.check_output(
756+
[sys.executable, "-c", script], stderr=subprocess.STDOUT
757+
)
758+
.decode("ascii")
759+
.strip()
760+
)
761+
expected_res = "test_case_1" in rc
762+
self.assertTrue(expected_res)
763+
764+
expected_ref = torch.xpu.is_available()
765+
expected_res = "test_case_2" not in rc
766+
self.assertEqual(expected_ref, expected_res)
767+
768+
expected_res = "test_case_3" in rc
769+
self.assertTrue(expected_res)
770+
771+
772+
instantiate_device_type_tests(TestSkipIfXpuDecroator, globals(), allow_xpu=True)
773+
696774
if __name__ == "__main__":
697775
run_tests()

torch/testing/_internal/common_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1917,16 +1917,23 @@ def xfailIfS390X(func):
19171917
return unittest.expectedFailure(func) if IS_S390X else func
19181918

19191919
def skipIfXpu(func=None, *, msg="test doesn't currently work on the XPU stack"):
1920-
def dec_fn(fn):
1921-
reason = f"skipIfXpu: {msg}"
1920+
reason = f"skipIfXpu: {msg}"
1921+
1922+
if isinstance(func, type):
1923+
if torch.xpu.is_available():
1924+
func.__unittest_skip__ = True # type: ignore[attr-defined]
1925+
func.__unittest_skip_why__ = reason # type: ignore[attr-defined]
1926+
return func
19221927

1928+
def dec_fn(fn):
19231929
@wraps(fn)
19241930
def wrapper(*args, **kwargs):
19251931
if TEST_XPU:
19261932
raise unittest.SkipTest(reason)
19271933
else:
19281934
return fn(*args, **kwargs)
19291935
return wrapper
1936+
19301937
if func:
19311938
return dec_fn(func)
19321939
return dec_fn

0 commit comments

Comments
 (0)
0