### 🐛 Describe the bug

```python
import torch

def foo(x):
    return torch.ones(x.shape)

x = torch.randn(3)

with torch.device("cuda"):
    print(torch.get_default_device())  # cpu
    print(foo(x))
```

cc: @ezyang (original author of torch.device as ctx support - https://github.com/pytorch/pytorch/pull/91796)

### Versions

main branch

cc @albanD
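
A possible way to check which device is actually in effect inside the context manager is to allocate an empty tensor and read its `.device`. This is only a sketch and assumes the report is about `torch.get_default_device()` not reflecting the `torch.device(...)` context. `effective_default_device` is a hypothetical helper name, not a PyTorch API, and the `meta` device stands in for `cuda` so the snippet runs without a GPU.

```python
import torch


def effective_default_device() -> torch.device:
    # Hypothetical helper (not part of PyTorch): factory functions such as
    # torch.empty honor the active torch.device context, so the device of a
    # freshly allocated empty tensor shows where new tensors would be placed.
    return torch.empty(0).device


# "meta" is used here instead of "cuda" so the sketch runs on any machine.
with torch.device("meta"):
    inside = effective_default_device()

# Outside the context, allocation falls back to the global default device.
outside = effective_default_device()

print(inside, outside)
```

Swapping `"meta"` for `"cuda"` on a GPU machine should give a direct comparison against the value printed by `torch.get_default_device()` in the repro above.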