[DataParallel] Skip for MPS device · pytorch/pytorch@a2a6723
[DataParallel] Skip for MPS device
As `torch._C._scatter` is defined only for CUDA/ROCm (and possibly XPU), this is a regression introduced by #141098 that went unnoticed due to #142206.

Test plan:
```
python test_autograd.py -v -k test_dataparallel_saved_tensors_hooks
```
Before this change it failed with
```
ERROR: test_dataparallel_saved_tensors_hooks (__main__.TestMultithreadAutograd.test_dataparallel_saved_tensors_hooks)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/malfet/git/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3108, in wrapper
    method(*args, **kwargs)
    ~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/malfet/git/pytorch/pytorch/test/test_autograd.py", line 13074, in test_dataparallel_saved_tensors_hooks
    model = torch.nn.DataParallel(Model())
  File "/Users/malfet/git/pytorch/pytorch/torch/nn/parallel/data_parallel.py", line 153, in __init__
    raise RuntimeError("no available devices were found")
RuntimeError: no available devices were found
```
After this change it passes.
1 parent bef1039 · commit a2a6723

File tree

1 file changed (+1, -1)


torch/nn/parallel/data_parallel.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -141,7 +141,7 @@ def __init__(
         super().__init__()
         torch._C._log_api_usage_once("torch.nn.parallel.DataParallel")
         device_type = _get_available_device_type()
-        if device_type is None:
+        if device_type is None or device_type == "mps":
             self.module = module
             self.device_ids = []
             return
```
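
For context, a minimal sketch (not part of the commit) of what the guarded path means in practice on a machine whose only accelerator is MPS, e.g. Apple Silicon: `DataParallel` now takes the same early-return branch as the CPU-only case and acts as a plain passthrough instead of raising.

```python
import torch
import torch.nn as nn

# Any module works; a tiny Linear keeps the example self-contained.
model = nn.Linear(4, 4)

# On an MPS-only machine, _get_available_device_type() returns "mps",
# so __init__ now takes the early-return branch shown in the diff:
# it keeps the module as-is and records no device ids, instead of
# raising "RuntimeError: no available devices were found".
dp = nn.DataParallel(model)

# With an empty device_ids list, DataParallel.forward() simply
# delegates to the wrapped module, i.e. it acts as a passthrough.
# (On a CUDA machine, device_ids would be populated instead.)
assert dp.device_ids == []
out = dp(torch.randn(2, 4))
print(out.shape)  # torch.Size([2, 4])
```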

0 commit comments
