8000 Support torch.device as accelerator's device switch context · pytorch/pytorch@21536ee · GitHub
[go: up one dir, main page]

Skip to content

Commit 21536ee

Browse files
committed
Support torch.device as accelerator's device switch context
ghstack-source-id: 01a7f2c Pull Request resolved: #148864
1 parent 3f069e7 commit 21536ee

File tree

5 files changed

+90
-5
lines changed

5 files changed

+90
-5
lines changed

test/test_accelerator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,23 @@ def test_current_stream_query(self):
7979
):
8080
torch.accelerator.current_stream(other_device)
8181

82+
def test_device_context_manager(self):
83+
prev_device = torch.accelerator.current_device_index()
84+
with torch.device("cpu"):
85+
self.assertEqual(torch.accelerator.current_device_index(), prev_device)
86+
self.assertEqual(torch.accelerator.current_device_index(), prev_device)
87+
with torch.accelerator.current_accelerator():
88+
self.assertEqual(torch.accelerator.current_device_index(), prev_device)
89+
90+
@unittest.skipIf(not TEST_MULTIACCELERATOR, "only one accelerator detected")
91+
def test_multi_device_context_manager(self):
92+
src_device = 0
93+
dst_device = 1
94+
torch.accelerator.set_device_index(src_device)
95+
with torch.device(dst_device):
96+
self.assertEqual(torch.accelerator.current_device_index(), dst_device)
97+
self.assertEqual(torch.accelerator.current_device_index(), src_device)
98+
8299
def test_stream_context_manager(self):
83100
prev_stream = torch.accelerator.current_stream()
84101
with torch.Stream() as s:

test/test_cuda.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,20 @@ def test_record_stream_on_shifted_view(self):
10111011

10121012
self.assertNotEqual(try_realloc.data_ptr(), data_ptr)
10131013

1014+
def test_device_context_manager(self):
1015+
prev_device = torch.cuda.current_device()
1016+
with torch.device("cpu"):
1017+
self.assertEqual(torch.cuda.current_device(), prev_device)
1018+
self.assertEqual(torch.cuda.current_device(), prev_device)
1019+
if not torch.cuda.device_count() > 1:
1020+
return
1021+
src_device = 0
1022+
dst_device = 1
1023+
torch.cuda.set_device(src_device)
1024+
with torch.device(dst_device):
1025+
self.assertEqual(torch.cuda.current_device(), dst_device)
1026+
self.assertEqual(torch.cuda.current_device(), src_device)
1027+
10141028
def test_stream_context_manager(self):
10151029
prev_stream = torch.cuda.current_stream()
10161030
with torch.cuda.Stream() as stream:

test/test_xpu.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,20 @@ def test_stream_compatibility(self):
309309
with self.assertRaisesRegex(RuntimeError, "The device index is out of range"):
310310
torch.accelerator.current_stream(torch.accelerator.device_count())
311311

312+
def test_device_context_manager(self):
313+
prev_device = torch.xpu.current_device()
314+
with torch.device("cpu"):
315+
self.assertEqual(torch.xpu.current_device(), prev_device)
316+
self.assertEqual(torch.xpu.current_device(), prev_device)
317+
if not torch.xpu.device_count() > 1:
318+
return
319+
src_device = 0
320+
dst_device = 1
321+
torch.xpu.set_device(src_device)
322+
with torch.device(dst_device):
323+
self.assertEqual(torch.xpu.current_device(), dst_device)
324+
self.assertEqual(torch.xpu.current_device(), src_device)
325+
312326
def test_stream_context_manager(self):
313327
prev_stream = torch.xpu.current_stream()
314328
with torch.xpu.Stream() as stream:

torch/csrc/Device.cpp

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ PyObject* THPDevice_New(const at::Device& device) {
2424
throw python_error();
2525
auto self_ = reinterpret_cast<THPDevice*>(self.get());
2626
self_->device = device;
27+
self_->context = nullptr;
2728
return self.release();
2829
}
2930

@@ -176,22 +177,59 @@ static PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) {
176177
END_HANDLE_TH_ERRORS
177178
}
178179

179-
static PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
180+
static PyObject* THPDevice_enter(PyObject* _self, PyObject* noargs) {
180181
HANDLE_TH_ERRORS
181182
py::object mode = py::module::import("torch.utils._device")
182-
.attr("DeviceContext")(py::handle(self));
183+
.attr("DeviceContext")(py::handle(_self));
183184
at::impl::PythonTorchFunctionTLS::push_onto_stack(
184185
std::make_shared<c10::SafePyObject>(
185186
mode.release().ptr(), getPyInterpreter()));
187+
auto self = (THPDevice*)_self;
188+
auto device_type = at::accelerator::getAccelerator();
189+
if (device_type.has_value() && device_type.value() == self->device.type() &&
190+
self->device.has_index()) {
191+
c10::DeviceIndex cur_device_idx = at::accelerator::getDeviceIndex();
192+
at::accelerator::setDeviceIndex(self->device.index());
193+
auto ctx_device_index =
194+
THPObjectPtr(THPUtils_packDeviceIndex(cur_device_idx));
195+
TORCH_CHECK(
196+
!(self->context), "Device's context should not be initialized.");
197+
auto dict = THPObjectPtr(PyDict_New());
198+
if (!dict) {
199+
throw python_error();
200+
}
201+
self->context = dict.release();
202+
if (PyDict_SetItemString(
203+
self->context, "_ctx_device_index", ctx_device_index.get()) < 0) {
204+
throw python_error();
205+
}
206+
}
186207
// So that with torch.device('cuda') as dev: works
187-
Py_INCREF(self);
188-
return self;
208+
Py_INCREF(_self);
209+
return _self;
189210
END_HANDLE_TH_ERRORS
190211
}
191212

192-
static PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
213+
static PyObject* THPDevice_exit(PyObject* _self, PyObject* unused) {
193214
HANDLE_TH_ERRORS
194215
at::impl::PythonTorchFunctionTLS::pop_stack();
216+
auto self = (THPDevice*)_self;
217+
auto device_type = at::accelerator::getAccelerator();
218+
if (device_type.has_value() && device_type.value() == self->device.type() &&
219+
self->device.has_index()) {
220+
PyObject* py_device_index = nullptr;
221+
if (PyDict_GetItemStringRef(
222+
self->context, "_ctx_device_index", &py_device_index) < 0) {
223+
throw python_error();
224+
}
225+
auto ctx_device_index = THPObjectPtr(py_device_index);
226+
TORCH_INTERNAL_ASSERT(
227+
ctx_device_index.get(),
228+
"ctx_device_index should be present on the context dict.");
229+
auto prev_device_index = THPUtils_unpackDeviceIndex(ctx_device_index.get());
230+
at::accelerator::setDeviceIndex(prev_device_index);
231+
Py_CLEAR(self->context);
232+
}
195233
Py_RETURN_NONE;
196234
END_HANDLE_TH_ERRORS
197235
}

torch/csrc/Device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
struct TORCH_API THPDevice {
1010
PyObject_HEAD
1111
at::Device device;
12+
// Used to switch device context management, initialized lazily.
13+
PyObject* context;
1214
};
1315

1416
TORCH_API extern PyTypeObject THPDeviceType;

0 commit comments

Comments
 (0)
0