Add torch.accelerator.device_index as accelerator's device switch context by guangyey · Pull Request #148864 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Add torch.accelerator.device_index as accelerator's device switch context #148864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 22 commits into from
Closed
Next Next commit
Update
[ghstack-poisoned]
  • Loading branch information
guangyey committed Mar 10, 2025
commit 959d05654678860578a42faf1fe1655fe9c3a15e
36 changes: 36 additions & 0 deletions torch/csrc/Device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ PyObject* THPDevice_New(const at::Device& device) {
throw python_error();
auto self_ = reinterpret_cast<THPDevice*>(self.get());
self_->device = device;
self_->context = nullptr;
return self.release();
}

Expand Down Expand Up @@ -183,6 +184,25 @@ static PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
at::impl::PythonTorchFunctionTLS::push_onto_stack(
std::make_shared<c10::SafePyObject>(
mode.release().ptr(), getPyInterpreter()));
auto device_type = at::accelerator::getAccelerator();
if (device_type.has_value() && device_type.value() == self->device.type() &&
self->device.has_index()) {
c10::DeviceIndex cur_device_idx = at::accelerator::getDeviceIndex();
at::accelerator::setDeviceIndex(self->device.index());
auto ctx_device_index =
THPObjectPtr(THPUtils_packDeviceIndex(cur_device_idx));
TORCH_CHECK(
!(self->context), "Device's context should not be initialized.");
auto dict = THPObjectPtr(PyDict_New());
if (!dict) {
throw python_error();
}
self->context = dict.release();
if (PyDict_SetItemString(
self->context, "_ctx_device_index", ctx_device_index.get()) < 0) {
throw python_error();
}
}
// So that with torch.device('cuda') as dev: works
Py_INCREF(self);
return self;
Expand All @@ -192,6 +212,22 @@ static PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
// Exit half of the device object's context-manager protocol.
//
// Pops the top entry off the PythonTorchFunctionTLS stack and, if a context
// dict is attached to this object, restores the accelerator device index that
// was saved in it under "_ctx_device_index".
//
// Parameters:
//   self   - the device object; must be a THPDevice.
//   unused - exit arguments; ignored.
// Returns: Python None.
// Errors: a failed dict lookup raises python_error; a missing
//   "_ctx_device_index" entry trips an internal assert. Presumably both are
//   translated into Python exceptions by the HANDLE_TH_ERRORS /
//   END_HANDLE_TH_ERRORS macros, which are defined elsewhere.
static PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
  HANDLE_TH_ERRORS
  at::impl::PythonTorchFunctionTLS::pop_stack();

  // `self` arrives as a bare PyObject*; the context field lives on THPDevice,
  // so it must be reached through the concrete type.
  auto* self_ = reinterpret_cast<THPDevice*>(self);

  // Restore only when a context dict is actually attached. Keying on the
  // stored state (rather than re-deriving the accelerator/index condition)
  // keeps a stray __exit__ without a matching __enter__ from dereferencing a
  // null dict.
  if (self_->context != nullptr) {
    // Take ownership of the dict and detach it from the object up front, so
    // the field is reset even if one of the steps below throws. Otherwise a
    // stale dict would be left behind on the device object.
    THPObjectPtr context(self_->context);
    self_->context = nullptr;

    PyObject* py_device_index = nullptr;
    if (PyDict_GetItemStringRef(
            context.get(), "_ctx_device_index", &py_device_index) < 0) {
      throw python_error();
    }
    // PyDict_GetItemStringRef hands back a strong reference; let
    // THPObjectPtr release it.
    THPObjectPtr ctx_device_index(py_device_index);
    TORCH_INTERNAL_ASSERT(
        ctx_device_index.get(),
        "ctx_device_index should be present on the context dict.");

    const auto prev_device_index =
        THPUtils_unpackDeviceIndex(ctx_device_index.get());
    at::accelerator::setDeviceIndex(prev_device_index);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/Device.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
// Python object layout for a device object: the standard Python object header
// followed by the wrapped at::Device and per-object context-manager state.
struct TORCH_API THPDevice {
PyObject_HEAD
// The device value this Python object wraps.
at::Device device;
// State for using the device object as a device-switching context manager.
// Null on a freshly created object. On enter, when the device matches the
// current accelerator and has an index, a dict is stored here holding the
// previously active device index under the key "_ctx_device_index"; exit
// reads that entry to restore the index and then clears this field.
// NOTE(review): this is an owned reference. Confirm the type's dealloc
// releases it, since that code is not visible here.
PyObject* context;
};

TORCH_API extern PyTypeObject THPDeviceType;
Expand Down
Loading
0