Correct torch.xpu.is_bf16_supported return False if no XPU detected · pytorch/pytorch@260349b
Commit 260349b

ghstack-source-id: 1de9b27
Pull Request resolved: #152317

1 parent 13dcf80 · commit 260349b

File tree

2 files changed: +22 −7 lines

  test/test_xpu.py
  torch/xpu/__init__.py

test/test_xpu.py (15 additions, 5 deletions)

@@ -22,7 +22,6 @@
     find_library_location,
     IS_LINUX,
     IS_WINDOWS,
-    NoTest,
     run_tests,
     suppress_warnings,
     TEST_XPU,
@@ -31,10 +30,6 @@
 from torch.utils.checkpoint import checkpoint_sequential
 
 
-if not TEST_XPU:
-    print("XPU not available, skipping tests", file=sys.stderr)
-    TestCase = NoTest  # noqa: F811
-
 TEST_MULTIXPU = torch.xpu.device_count() > 1
 
 cpu_device = torch.device("cpu")
@@ -74,6 +69,7 @@
 ]
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpu(TestCase):
     def test_device_behavior(self):
         current_device = torch.xpu.current_device()
@@ -581,6 +577,7 @@ def test_dlpack_conversion(self):
 instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuAutocast(TestAutocast):
     # These operators are not implemented on XPU backend and we can NOT fall back
     # them to CPU. So we have to skip them at this moment.
@@ -661,6 +658,7 @@ def test_xpu_autocast_dtype(self):
         self.assertEqual(result.dtype, torch.float16)
 
 
+@unittest.skipIf(not TEST_XPU, "XPU not available, skipping tests")
 class TestXpuTrace(TestCase):
     def setUp(self):
         torch._C._activate_gpu_trace()
@@ -723,5 +721,17 @@ def test_event_synchronization_callback(self):
         self.mock.assert_called_once_with(event._as_parameter_.value)
 
 
+class TestXPUAPISanity(TestCase):
+    def test_is_bf16_supported(self):
+        self.assertEqual(
+            torch.xpu.is_bf16_supported(including_emulation=True),
+            torch.xpu.is_available(),
+        )
+
+    def test_get_arch_list(self):
+        if not torch.xpu._is_compiled():
+            self.assertEqual(len(torch.xpu.get_arch_list()), 0)
+
+
 if __name__ == "__main__":
     run_tests()
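Note the change in skip strategy above: instead of rebinding TestCase to NoTest at import time when TEST_XPU is false, each class now carries @unittest.skipIf, so absent hardware is reported as skipped tests rather than classes that silently never run. A minimal, self-contained sketch of the same pattern (the HAS_XPU flag and ExampleTest class are illustrative stand-ins, not part of the commit):

import unittest

# Illustrative stand-in for the real TEST_XPU availability check.
HAS_XPU = False

@unittest.skipIf(not HAS_XPU, "XPU not available, skipping tests")
class ExampleTest(unittest.TestCase):
    def test_something(self):
        self.assertTrue(True)

if __name__ == "__main__":
    # With HAS_XPU False, the runner reports the test as "skipped"
    # instead of discarding the class entirely.
    unittest.main()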

torch/xpu/__init__.py (7 additions, 2 deletions)

@@ -66,9 +66,14 @@ def is_available() -> bool:
     return device_count() > 0
 
 
-def is_bf16_supported():
+def is_bf16_supported(including_emulation: bool = True) -> bool:
     r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
-    return True
+    if not is_available():
+        return False
+    return (
+        including_emulation
+        or torch.xpu.get_device_properties().has_bfloat16_conversions
+    )
 
 
 def is_initialized():
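After this change, is_bf16_supported() is gated on is_available(), and passing including_emulation=False additionally requires the device to report has_bfloat16_conversions. A hedged usage sketch of how a caller might branch on the corrected API (the tensor shapes here are arbitrary, not from the commit):

import torch

# Before this fix, torch.xpu.is_bf16_supported() returned True
# unconditionally, even on machines with no XPU; it now returns
# False when no device is detected.
if torch.xpu.is_available() and torch.xpu.is_bf16_supported(including_emulation=False):
    # Device reports native bfloat16 conversion support.
    x = torch.randn(8, 8, device="xpu", dtype=torch.bfloat16)
else:
    # Fall back to float32 on CPU (or on an XPU without native bf16).
    x = torch.randn(8, 8, dtype=torch.float32)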
