22 | 22 |     find_library_location,
23 | 23 |     IS_LINUX,
24 | 24 |     IS_WINDOWS,
25 |    | -   NoTest,
26 | 25 |     run_tests,
27 | 26 |     suppress_warnings,
28 | 27 |     TEST_XPU,

31 | 30 | from torch.utils.checkpoint import checkpoint_sequential
32 | 31 |
33 | 32 |
34 |    | -if not TEST_XPU:
35 |    | -    print("XPU not available, skipping tests", file=sys.stderr)
36 |    | -    TestCase = NoTest  # noqa: F811
37 |    | -
38 | 33 | TEST_MULTIXPU = torch.xpu.device_count() > 1
39 | 34 |
40 | 35 | cpu_device = torch.device("cpu")

74 | 69 | ]
75 | 70 |
76 | 71 |
   | 72 | +@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
77 | 73 | class TestXpu(TestCase):
78 | 74 |     def test_device_behavior(self):
79 | 75 |         current_device = torch.xpu.current_device()
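The change above replaces the module-level skip, which rebound `TestCase` to `NoTest` and silently skipped every test in the file, with an explicit `@unittest.skipIf` decorator on each class that actually needs the device; hardware-independent classes (like the new `TestXPUAPISanity` at the end of this diff) can then still run. A minimal sketch of the two patterns, with hypothetical class names standing in for the real test classes:

```python
import unittest

# Stand-in for torch.testing._internal.common_utils.TEST_XPU; hardcoded to
# False here so the skip path is exercised on a machine without an XPU.
TEST_XPU = False


# New pattern: decorate only the classes that need the device.
@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
class TestNeedsXpu(unittest.TestCase):  # hypothetical name
    def test_on_device(self):
        self.fail("never reached when TEST_XPU is False")


# Hardware-independent checks stay undecorated and always run.
class TestRunsAnywhere(unittest.TestCase):  # hypothetical name
    def test_pure_python(self):
        self.assertTrue(callable(unittest.skipIf))


if __name__ == "__main__":
    unittest.main()
```
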
@@ -581,6 +577,7 @@ def test_dlpack_conversion(self):
581 | 577 | instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
582 | 578 |
583 | 579 |
    | 580 | +@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
584 | 581 | class TestXpuAutocast(TestAutocast):
585 | 582 |     # These operators are not implemented on XPU backend and we can NOT fall back
586 | 583 |     # them to CPU. So we have to skip them at this moment.
@@ -661,6 +658,7 @@ def test_xpu_autocast_dtype(self):
661 | 658 |         self.assertEqual(result.dtype, torch.float16)
662 | 659 |
663 | 660 |
    | 661 | +@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
664 | 662 | class TestXpuTrace(TestCase):
665 | 663 |     def setUp(self):
666 | 664 |         torch._C._activate_gpu_trace()
@@ -723,5 +721,17 @@ def test_event_synchronization_callback(self):
723 | 721 |         self.mock.assert_called_once_with(event._as_parameter_.value)
724 | 722 |
725 | 723 |
    | 724 | +class TestXPUAPISanity(TestCase):
    | 725 | +    def test_is_bf16_supported(self):
    | 726 | +        self.assertEqual(
    | 727 | +            torch.xpu.is_bf16_supported(including_emulation=True),
    | 728 | +            torch.xpu.is_available(),
    | 729 | +        )
    | 730 | +
    | 731 | +    def test_get_arch_list(self):
    | 732 | +        if not torch.xpu._is_compiled():
    | 733 | +            self.assertEqual(len(torch.xpu.get_arch_list()), 0)
    | 734 | +
    | 735 | +
726 | 736 | if __name__ == "__main__":
727 | 737 |     run_tests()
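The new `TestXPUAPISanity` class is deliberately left without the skip decorator: it checks that the `torch.xpu` query APIs stay callable and consistent even on builds and machines with no XPU. A quick way to poke at the same APIs interactively, assuming a PyTorch build recent enough to ship the `torch.xpu` module:

```python
import torch

# On a machine without an XPU these should all degrade gracefully rather
# than raise, which is exactly what TestXPUAPISanity asserts above.
print(torch.xpu.is_available())    # False without a usable device
print(torch.xpu.device_count())    # 0 in that case
print(torch.xpu.get_arch_list())   # [] when torch wasn't compiled for XPU
print(torch.xpu.is_bf16_supported(including_emulation=True))
# ^ mirrors torch.xpu.is_available(), per the assertion in the test
```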