[RELAND] Add UTs for accelerator device-agnostic runtime APIs (#133572) · pytorch/pytorch@2091194 · GitHub

Commit 2091194

guangyey authored and pytorchmergebot committed
[RELAND] Add UTs for accelerator device-agnostic runtime APIs (#133572)
# Motivation

This PR adds unit tests for the accelerator device-agnostic runtime APIs.

# Additional Context

This PR is a reland. The original landing was reverted because `torch.Event` did not support the mps backend; that is fixed in #142468. The previous commit is 952514f.

Pull Request resolved: #133572
Approved by: https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #142468
1 parent 8815402 commit 2091194
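
As a quick orientation before the diffs, here is a minimal sketch of the device-agnostic runtime API that these tests exercise. It uses only calls that appear in the tests below; the tensor sizes and printed messages are illustrative, and it assumes a build with at least one supported accelerator (cuda, xpu, or mps) available.

    import torch

    if torch.accelerator.is_available():
        # Query the active backend without hard-coding "cuda"/"xpu"/"mps".
        acc = torch.accelerator.current_accelerator()
        print(f"Running on {acc.type} with {torch.accelerator.device_count()} device(s)")

        # Backend-agnostic stream and event handles.
        s = torch.Stream()
        torch.accelerator.set_stream(s)

        x = torch.randn(100).to(acc, non_blocking=True)
        y = x + 1

        ev = torch.Event()
        ev.record(s)                      # mark the point after the add was enqueued on s
        torch.accelerator.synchronize()   # block until the current device is idle
        print(ev.query(), y.cpu().sum())  # event has completed; result is safe to read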

File tree

3 files changed: +89, -0 lines changed


test/test_accelerator.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+# Owner(s): ["module: tests"]
+
+import sys
+import unittest
+
+import torch
+from torch.testing._internal.common_utils import NoTest, run_tests, TestCase
+
+
+if not torch.accelerator.is_available():
+    print("No available accelerator detected, skipping tests", file=sys.stderr)
+    TestCase = NoTest  # noqa: F811
+
+TEST_MULTIACCELERATOR = torch.accelerator.device_count() > 1
+
+
+class TestAccelerator(TestCase):
+    def test_current_accelerator(self):
+        self.assertTrue(torch.accelerator.is_available())
+        accelerators = ["cuda", "xpu", "mps"]
+        for accelerator in accelerators:
+            if torch.get_device_module(accelerator).is_available():
+                self.assertEqual(
+                    torch.accelerator.current_accelerator().type, accelerator
+                )
+                self.assertIsNone(torch.accelerator.current_accelerator().index)
+                with self.assertRaisesRegex(
+                    ValueError, "doesn't match the current accelerator"
+                ):
+                    torch.accelerator.set_device_idx("cpu")
+
+    @unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
+    def test_generic_multi_device_behavior(self):
+        orig_device = torch.accelerator.current_device_idx()
+        target_device = (orig_device + 1) % torch.accelerator.device_count()
+
+        torch.accelerator.set_device_idx(target_device)
+        self.assertEqual(target_device, torch.accelerator.current_device_idx())
+        torch.accelerator.set_device_idx(orig_device)
+        self.assertEqual(orig_device, torch.accelerator.current_device_idx())
+
+        s1 = torch.Stream(target_device)
+        torch.accelerator.set_stream(s1)
+        self.assertEqual(target_device, torch.accelerator.current_device_idx())
+        torch.accelerator.synchronize(orig_device)
+        self.assertEqual(target_device, torch.accelerator.current_device_idx())
+
+    def test_generic_stream_behavior(self):
+        s1 = torch.Stream()
+        s2 = torch.Stream()
+        torch.accelerator.set_stream(s1)
+        self.assertEqual(torch.accelerator.current_stream(), s1)
+        event = torch.Event()
+        a = torch.randn(100)
+        b = torch.randn(100)
+        c = a + b
+        torch.accelerator.set_stream(s2)
+        self.assertEqual(torch.accelerator.current_stream(), s2)
+        a_acc = a.to(torch.accelerator.current_accelerator(), non_blocking=True)
+        b_acc = b.to(torch.accelerator.current_accelerator(), non_blocking=True)
+        torch.accelerator.set_stream(s1)
+        self.assertEqual(torch.accelerator.current_stream(), s1)
+        event.record(s2)
+        event.synchronize()
+        c_acc = a_acc + b_acc
+        event.record(s2)
+        torch.accelerator.synchronize()
+        self.assertTrue(event.query())
+        self.assertEqual(c_acc.cpu(), c)
+
+
+if __name__ == "__main__":
+    run_tests()

test/test_cuda.py

Lines changed: 8 additions & 0 deletions
@@ -725,6 +725,14 @@ def test_generic_stream_event(self):
         self.assertTrue(issubclass(type(cuda_event), torch.Event))
         self.assertTrue(torch.Event in type(cuda_event).mro())
 
+    def test_stream_compatibility(self):
+        s1 = torch.cuda.Stream()
+        s2 = torch.cuda.Stream()
+        torch.accelerator.set_stream(s1)
+        self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
+        torch.accelerator.set_stream(s2)
+        self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
+
     def test_record_stream(self):
         cycles_per_ms = get_cycles_per_ms()

test/test_xpu.py

Lines changed: 8 additions & 0 deletions
@@ -299,6 +299,14 @@ def test_generic_stream_event(self):
         self.assertTrue(issubclass(type(xpu_event), torch.Event))
         self.assertTrue(torch.Event in type(xpu_event).mro())
 
+    def test_stream_compatibility(self):
+        s1 = torch.xpu.Stream()
+        s2 = torch.xpu.Stream()
+        torch.accelerator.set_stream(s1)
+        self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
+        torch.accelerator.set_stream(s2)
+        self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
+
     def test_generator(self):
         torch.manual_seed(2024)
         g_state0 = torch.xpu.get_rng_state()

0 commit comments
