`torch.device.__enter__` does not affect `get_default_device` despite taking precedence over `set_default_device` · Issue #148874 · pytorch/pytorch · GitHub


Open · ringohoffman opened this issue Mar 10, 2025 · 2 comments
Labels
module: python frontend For issues relating to PyTorch's Python frontend · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

ringohoffman (Contributor) commented Mar 10, 2025

🐛 Describe the bug

Using a torch.device as a context manager takes precedence over set_default_device, but this isn't reflected by the return value of get_default_device.

import torch
import torch.utils._device

torch.set_default_device("cuda:1")

with torch.device("cuda:0"):
    print(f"get_default_device(): {torch.get_default_device()}")
    print(f"CURRENT_DEVICE: {torch.utils._device.CURRENT_DEVICE}")
    print(f"actual current device: {torch.tensor(()).device}")
# get_default_device(): cuda:1
# CURRENT_DEVICE: cuda:1
# actual current device: cuda:0

Calling __enter__ on the DeviceContext that torch.device's C++ __enter__ implementation creates (and calling __exit__ on it in the C++ __exit__ implementation) might be a solution:

static PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  py::object mode = py::module::import("torch.utils._device")
                        .attr("DeviceContext")(py::handle(self));
  at::impl::PythonTorchFunctionTLS::push_onto_stack(
      std::make_shared<c10::SafePyObject>(
          mode.release().ptr(), getPyInterpreter()));
  // So that with torch.device('cuda') as dev: works
  Py_INCREF(self);
  return self;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPDevice_exit(PyObject* self, PyObject* unused) {
  HANDLE_TH_ERRORS
  at::impl::PythonTorchFunctionTLS::pop_stack();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

def __torch_function__(self, func, types, args=(), kwargs=None):
    kwargs = kwargs or {}
    if func in _device_constructors() and kwargs.get('device') is None:
        kwargs['device'] = self.device
    return func(*args, **kwargs)

pytorch/torch/__init__.py

Lines 1134 to 1147 in 00199ac

def get_default_device() -> "torch.device":
    r"""Gets the default ``torch.Tensor`` to be allocated on ``device``"""
    global _GLOBAL_DEVICE_CONTEXT
    if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"):
        device = _GLOBAL_DEVICE_CONTEXT.device_context.device
        if device.index is not None:
            return device
        else:
            # TODO: Call like get_device_index() method corresponding to
            # each device type
            return torch.tensor([]).device
    else:
        return torch.device("cpu")
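For illustration, here is a minimal pure-Python model of the mismatch (the names `_mode_stack`, `_global_device`, and `device` below are stand-ins, not PyTorch internals): set_default_device records state in a global that get_default_device reads, while the context manager only pushes onto a mode stack that tensor constructors consult but get_default_device never does.

```python
# Stand-in model of the reported behavior; names are illustrative,
# not the real PyTorch internals.
_mode_stack = []        # models the torch-function mode stack (TLS)
_global_device = "cpu"  # models _GLOBAL_DEVICE_CONTEXT

def set_default_device(dev):
    """Models torch.set_default_device: updates only the global."""
    global _global_device
    _global_device = dev

def get_default_device():
    """Models the current torch.get_default_device: ignores the stack."""
    return _global_device

def get_default_device_fixed():
    """Models the expected behavior: the innermost device context wins."""
    return _mode_stack[-1] if _mode_stack else _global_device

class device:
    """Models `with torch.device(...)`: pushes onto the mode stack."""
    def __init__(self, dev):
        self.dev = dev
    def __enter__(self):
        _mode_stack.append(self.dev)
        return self
    def __exit__(self, *exc):
        _mode_stack.pop()

set_default_device("cuda:1")
with device("cuda:0"):
    print(get_default_device())        # cuda:1  (the bug)
    print(get_default_device_fixed())  # cuda:0  (what users expect)
```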

cc: @ezyang

Versions

torch==2.6.0

cc @albanD

ringohoffman (Contributor, Author) commented Mar 10, 2025

Actually, if you call set_default_device inside a torch.device context manager, it crashes the second time you call set_default_device:

import torch
import torch.utils._device

with torch.device("cuda:0"):
    torch.set_default_device("cuda:1")
    print(f"get_default_device(): {torch.get_default_device()}")
    print(f"CURRENT_DEVICE: {torch.utils._device.CURRENT_DEVICE}")
    print(f"actual current device: {torch.tensor(()).device}")

with torch.device("cuda:0"):
    torch.set_default_device("cuda:1")
    print(f"get_default_device(): {torch.get_default_device()}")
    print(f"CURRENT_DEVICE: {torch.utils._device.CURRENT_DEVICE}")
    print(f"actual current device: {torch.tensor(()).device}")
# get_default_device(): cuda:1
# CURRENT_DEVICE: cuda:1
# actual current device: cuda:0

AssertionError                            Traceback (most recent call last)
Cell In[1], line 11
      8     print(f"actual current device: {torch.tensor(()).device}")
     10 with torch.device("cuda:0"):
---> 11     torch.set_default_device("cuda:1")
     12     print(f"get_default_device(): {torch.get_default_device()}")
     13     print(f"CURRENT_DEVICE: {torch.utils._device.CURRENT_DEVICE}")

File ~//lib/python3.10/site-packages/torch/__init__.py:1200, in set_default_device(device)
   1198     device_context = _GLOBAL_DEVICE_CONTEXT.device_context
   1199     if device_context is not None:
-> 1200         device_context.__exit__(None, None, None)
   1202 if device is None:
   1203     device_context = None

File ~//lib/python3.10/site-packages/torch/utils/_device.py:90, in DeviceContext.__exit__(self, exc_type, exc_val, exc_tb)
     88 for _ in range(_len_torch_function_stack() - 1):
     89     mode = _pop_mode()
---> 90     assert not isinstance(mode, DeviceContext)
     91     cur_stack.append(mode)
     93 if _len_torch_function_stack() > 0:

AssertionError: 

@soulitzer added the triaged and module: python frontend labels Mar 10, 2025
vadimkantorov (Contributor) commented May 10, 2025

This is especially weird with torch.device("meta"):

import torch

with torch.device('meta'):
    print(torch.empty(()).device)
    print(torch.get_default_device())

# meta
# cpu

E.g. probably here:

https://github.com/huggingface/transformers/blob/716819b8309324302e00a3488a3c3d6faa427f79/src/transformers/modeling_utils.py#L833-L834:

if is_fsdp_enabled():
    param_device = "cpu" if is_local_dist_rank_0() else "meta"

It should then say something like:

param_device = "cpu" if (is_local_dist_rank_0() and torch.get_default_device().type != "meta") else "meta"

(using "cpu" while the default device is "meta" leads to errors like huggingface/transformers#38066)
