diff --git a/torch/multiprocessing/reductions.py b/torch/multiprocessing/reductions.py index 4e4539396f8323..307356b3e02c4d 100644 --- a/torch/multiprocessing/reductions.py +++ b/torch/multiprocessing/reductions.py @@ -169,6 +169,7 @@ def rebuild_cuda_tensor( event_handle, event_sync_required, ): + storage_device = _device_from_uuid(storage_device) # If storage_handle is None, storage points to nullptr. if storage_handle is None or storage_size_bytes == 0: storage = storage_cls(0, dtype=dtype, device=storage_device, _internal=True) @@ -365,7 +366,7 @@ def reduce_tensor(tensor): tensor_offset, # tensor offset in its storage type(storage), tensor.dtype, - device, + _device_to_uuid(device), handle, # identifier which CUDA allocation is the storage in. storage_size_bytes, # size(in bytes) of the storage storage_offset_bytes, # offset(in bytes) of the storage in the CUDA allocation @@ -645,3 +646,14 @@ def init_reductions(): from torch.nn.parameter import Parameter reduction.register(Parameter, reduce_tensor) + + +def _device_to_uuid(device): + return str(torch.cuda.get_device_properties(device).uuid) + + +def _device_from_uuid(device_uuid): + for device in range(torch.cuda.device_count()): + if str(torch.cuda.get_device_properties(device).uuid) == device_uuid: + return device + raise Exception("Invalid device_uuid=" + device_uuid)