Cherry-pick reverts dfe5669 and 1b3f8b7 (#143092) · pytorch/pytorch@aad1c16 · GitHub
[go: up one dir, main page]

Skip to content

Commit aad1c16

Browse files
Cherry-pick reverts dfe5669 and 1b3f8b7 (#143092)
* Revert "[RELAND] Add device-agnostic runtime Device/Stream C++ API (#138677)" This reverts commit 734bb01. Reverted #138677 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the new test is still very flaky on MacOS even when it does not segfault anymore ([comment](#133572 (comment))) * Revert "[RELAND] Add UTs for accelerator device-agnostic runtime APIs (#133572)" This reverts commit 2091194. Reverted #133572 on behalf of https://github.com/huydhn due to Sorry for reverting your change but the new test is still very flaky on MacOS even when it does not segfault anymore ([comment](#133572 (comment))) --------- Co-authored-by: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
1 parent af92bad commit aad1c16

File tree

7 files changed

+39
-195
lines changed

7 files changed

+39
-195
lines changed

aten/src/ATen/DeviceAccelerator.cpp

Lines changed: 4 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
#include <ATen/Context.h>
22
#include <ATen/DeviceAccelerator.h>
3-
#include <c10/core/impl/VirtualGuardImpl.h>
4-
5-
namespace at::accelerator {
3+
namespace at {
64

75
std::optional<c10::DeviceType> getAccelerator(bool checked) {
86
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
@@ -39,8 +37,8 @@ std::optional<c10::DeviceType> getAccelerator(bool checked) {
3937
#undef DETECT_AND_ASSIGN_ACCELERATOR
4038
}
4139

42-
bool isAccelerator(c10::DeviceType device_type) {
43-
switch (device_type) {
40+
bool isAccelerator(c10::DeviceType d) {
41+
switch (d) {
4442
case at::kCUDA:
4543
case at::kMTIA:
4644
case at::kXPU:
@@ -54,50 +52,4 @@ bool isAccelerator(c10::DeviceType device_type) {
5452
}
5553
}
5654

57-
c10::DeviceIndex deviceCount() {
58-
const auto device_type = getAccelerator(false);
59-
if (!device_type.has_value()) {
60-
return static_cast<c10::DeviceIndex>(0);
61-
}
62-
c10::impl::VirtualGuardImpl impl(device_type.value());
63-
return static_cast<c10::DeviceIndex>(impl.deviceCount());
64-
}
65-
66-
void setDeviceIndex(c10::DeviceIndex device_index) {
67-
const auto device_type = getAccelerator(true).value();
68-
c10::impl::VirtualGuardImpl impl(device_type);
69-
impl.setDevice({device_type, device_index});
70-
}
71-
72-
c10::DeviceIndex getDeviceIndex() {
73-
const auto device_type = getAccelerator(true).value();
74-
c10::impl::VirtualGuardImpl impl(device_type);
75-
return static_cast<c10::DeviceIndex>(impl.getDevice().index());
76-
}
77-
78-
void setCurrentStream(c10::Stream stream) {
79-
const auto device_type = getAccelerator(true).value();
80-
TORCH_CHECK(
81-
device_type == stream.device_type(),
82-
"stream's device type ",
83-
c10::DeviceTypeName(stream.device_type()),
84-
" doesn't match the current accelerator ",
85-
c10::DeviceTypeName(device_type));
86-
c10::impl::VirtualGuardImpl impl(device_type);
87-
impl.exchangeStream(stream);
88-
}
89-
90-
c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
91-
const auto device_type = getAccelerator(true).value();
92-
c10::impl::VirtualGuardImpl impl(device_type);
93-
return impl.getStream({device_type, device_index});
94-
}
95-
96-
void synchronizeDevice(c10::DeviceIndex device_index) {
97-
const auto device_type = getAccelerator(true).value();
98-
c10::impl::VirtualGuardImpl impl(device_type);
99-
// impl.synchronizeDevice should can be safely called from any device
100-
impl.synchronizeDevice(device_index);
101-
}
102-
103-
} // namespace at::accelerator
55+
} // namespace at

aten/src/ATen/DeviceAccelerator.h

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
#include <ATen/detail/MTIAHooksInterface.h>
77
#include <optional>
88

9-
namespace at::accelerator {
10-
119
// This file defines the top level Accelerator concept for PyTorch.
1210
// A device is an accelerator per the definition here if:
1311
// - It is mutually exclusive with all other accelerators
@@ -17,39 +15,13 @@ namespace at::accelerator {
1715
// As of today, accelerator devices are (in no particular order):
1816
// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
1917

18+
namespace at {
19+
2020
// Ensures that only one accelerator is available (at
2121
// compile time if possible) and return it.
2222
// When checked is true, the returned optional always has a value.
2323
TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
2424

25-
// Check if the given device type is an accelerator.
26-
TORCH_API bool isAccelerator(c10::DeviceType device_type);
27-
28-
// Return the number of the device available. Note that this is *REQUIRED* to
29-
// not raise any exception.
30-
TORCH_API c10::DeviceIndex deviceCount();
31-
32-
// Set the current device index to the given device index.
33-
TORCH_API void setDeviceIndex(c10::DeviceIndex device_index);
34-
35-
// Get the current device index.
36-
TORCH_API c10::DeviceIndex getDeviceIndex();
25+
TORCH_API bool isAccelerator(c10::DeviceType d);
3726

38-
// Set the current stream to a given stream. Note that this API doesn't change
39-
// the current device index.
40-
TORCH_API void setCurrentStream(c10::Stream stream);
41-
42-
// Get the current stream of the given device index.
43-
TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
44-
45-
// Wait (by blocking the calling thread) until all the work previously enqueued
46-
// on the given device index has been completed.
47-
TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
48-
49-
} // namespace at::accelerator
50-
51-
namespace at {
52-
// Keep BC only
53-
using at::accelerator::getAccelerator;
54-
using at::accelerator::isAccelerator;
5527
} // namespace at

test/test_accelerator.py

Lines changed: 0 additions & 73 deletions
This file was deleted.

test/test_cuda.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -725,14 +725,6 @@ def test_generic_stream_event(self):
725725
self.assertTrue(issubclass(type(cuda_event), torch.Event))
726726
self.assertTrue(torch.Event in type(cuda_event).mro())
727727

728-
def test_stream_compatibility(self):
729-
s1 = torch.cuda.Stream()
730-
s2 = torch.cuda.Stream()
731-
torch.accelerator.set_stream(s1)
732-
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
733-
torch.accelerator.set_stream(s2)
734-
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
735-
736728
def test_record_stream(self):
737729
cycles_per_ms = get_cycles_per_ms()
738730

test/test_xpu.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,14 +299,6 @@ def test_generic_stream_event(self):
299299
self.assertTrue(issubclass(type(xpu_event), torch.Event))
300300
self.assertTrue(torch.Event in type(xpu_event).mro())
301301

302-
def test_stream_compatibility(self):
303-
s1 = torch.xpu.Stream()
304-
s2 = torch.xpu.Stream()
305-
torch.accelerator.set_stream(s1)
306-
self.assertEqual(torch.accelerator.current_stream().stream_id, s1.stream_id)
307-
torch.accelerator.set_stream(s2)
308-
self.assertEqual(torch.accelerator.current_stream().stream_id, s2.stream_id)
309-
310302
def test_generator(self):
311303
torch.manual_seed(2024)
312304
g_state0 = torch.xpu.get_rng_state()

torch/csrc/DeviceAccelerator.cpp

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <c10/core/DeviceGuard.h>
12
#include <torch/csrc/DeviceAccelerator.h>
23
#include <torch/csrc/utils/device_lazy_init.h>
34

@@ -12,52 +13,68 @@ void initModule(PyObject* module) {
1213
});
1314

1415
m.def("_accelerator_deviceCount", []() {
15-
auto device_type = at::accelerator::getAccelerator(false);
16-
torch::utils::maybe_initialize_device(device_type);
17-
return at::accelerator::deviceCount();
16+
const auto device_type = at::getAccelerator(false);
17+
if (!device_type.has_value()) {
18+
return static_cast<c10::DeviceIndex>(0);
19+
}
20+
torch::utils::maybe_initialize_device(device_type.value());
21+
c10::impl::VirtualGuardImpl impl(device_type.value());
22+
return static_cast<c10::DeviceIndex>(impl.deviceCount());
1823
});
1924

2025
m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
26+
const auto device_type = at::getAccelerator(true).value();
2127
// If device index is negative, no-op
2228
if (device_index < 0) {
2329
return;
2430
}
25-
const auto device_type = at::accelerator::getAccelerator(true).value();
2631
torch::utils::maybe_initialize_device(device_type);
27-
at::accelerator::setDeviceIndex(device_index);
32+
c10::impl::VirtualGuardImpl impl(device_type);
33+
impl.setDevice({device_type, device_index});
2834
});
2935

3036
m.def("_accelerator_getDeviceIndex", []() {
31-
const auto device_type = at::accelerator::getAccelerator(true).value();
37+
const auto device_type = at::getAccelerator(true).value();
3238
torch::utils::maybe_initialize_device(device_type);
33-
return at::accelerator::getDeviceIndex();
39+
c10::impl::VirtualGuardImpl impl(device_type);
40+
return static_cast<c10::DeviceIndex>(impl.getDevice().index());
3441
});
3542

3643
m.def("_accelerator_setStream", [](c10::Stream stream) {
37-
const auto device_type = at::accelerator::getAccelerator(true).value();
44+
const auto device_type = at::getAccelerator(true).value();
45+
TORCH_CHECK(
46+
device_type == stream.device_type(),
47+
"stream's device type ",
48+
c10::DeviceTypeName(stream.device_type()),
49+
" doesn't match the current accelerator ",
50+
c10::DeviceTypeName(device_type));
3851
torch::utils::maybe_initialize_device(device_type);
52+
c10::impl::VirtualGuardImpl impl(device_type);
3953
// Set the current device to the device of stream
40-
if (at::accelerator::getDeviceIndex() != stream.device_index()) {
41-
at::accelerator::setDeviceIndex(stream.device_index());
54+
if (impl.getDevice().index() != stream.device_index()) {
55+
impl.setDevice(stream.device());
4256
}
43-
at::accelerator::setCurrentStream(stream);
57+
impl.exchangeStream(stream);
4458
});
4559

4660
m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
47-
const auto device_type = at::accelerator::getAccelerator(true).value();
61+
const auto device_type = at::getAccelerator(true).value();
4862
torch::utils::maybe_initialize_device(device_type);
49-
return at::accelerator::getCurrentStream(device_index);
63+
c10::impl::VirtualGuardImpl impl(device_type);
64+
return impl.getStream({device_type, device_index});
5065
});
5166

5267
m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
53-
const auto device_type = at::accelerator::getAccelerator(true).value();
68+
const auto device_type = at::getAccelerator(true).value();
5469
if (!torch::utils::is_device_initialized(device_type)) {
5570
return;
5671
}
5772
torch::utils::maybe_initialize_device(device_type);
73+
c10::impl::VirtualGuardImpl impl(device_type);
74+
// impl.synchronizeDevice should can be safely called from any device
5875
{
5976
py::gil_scoped_release no_gil;
60-
at::accelerator::synchronizeDevice(device_index);
77+
impl.synchronizeDevice(device_index);
6178
}
6279
});
6380
}

torch/csrc/utils/device_lazy_init.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,6 @@ inline void maybe_initialize_device(const at::TensorOptions& options) {
4646
maybe_initialize_device(device);
4747
}
4848

49-
inline void maybe_initialize_device(
50-
std::optional<at::DeviceType>& device_type) {
51-
if (!device_type.has_value()) {
52-
return;
53-
}
54-
maybe_initialize_device(device_type.value());
55-
}
56-
5749
bool is_device_initialized(at::DeviceType device_type);
5850

5951
} // namespace torch::utils

0 commit comments

Comments
 (0)
0