8000 [c10d] Fix extra CUDA context created by barrier · pytorch/pytorch@b5cc381 · GitHub
[go: up one dir, main page]

Skip to content

Commit b5cc381

Browse files
committed
[c10d] Fix extra CUDA context created by barrier
Fixes #149119. In ProcessGroup.hpp, we create a dummy tensor for dispatching. This requires a correct device index. This PR uses `device_id` given by user when calling `init_process_group`. This PR also uses `torch._C._get_accelerator()` to determine the device type. ghstack-source-id: 258020d Pull Request resolved: #149144
1 parent 8d08b49 commit b5cc381

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

torch/distributed/distributed_c10d.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4654,30 +4654,43 @@ def barrier(
46544654
group (ProcessGroup, optional): The process group to work on. If None,
46554655
the default process group will be used.
46564656
async_op (bool, optional): Whether this op should be an async op
4657-
device_ids ([int], optional): List of device/GPU ids.
4657+
device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
46584658
46594659
Returns:
46604660
Async work handle, if async_op is set to True.
46614661
None, if not async_op or if not part of the group
46624662
46634663
.. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
46644664
"""
4665+
group = group or _get_default_group()
4666+
46654667
if _rank_not_in_group(group):
46664668
_warn_not_in_group("barrier")
46674669
return
46684670

46694671
opts = BarrierOptions()
4670-
opts.device = torch.device(_get_object_coll_device(group))
46714672
opts.asyncOp = async_op
4672-
if device_ids is not None:
4673-
if isinstance(device_ids, list):
4674-
opts.device_ids = device_ids
4675-
else:
4676-
raise TypeError(
4677-
"Invalid function argument: device_ids type should be List[int]"
4678-
)
4673+
# Detect the accelerator on the machine. If no accelerator is available, it
4674+
# returns CPU.
4675+
device = torch._C._get_accelerator()
4676+
if isinstance(device_ids, list):
4677+
opts.device_ids = device_ids
4678+
# use only the first device id
4679+
opts.device = torch.device(device.type, device_ids[0])
4680+
elif group.bound_device_id is not None:
4681+
# Use device id from `init_process_group(device_id=...)`
4682+
opts.device = group.bound_device_id
4683+
elif device.type == "cpu" or get_backend(group) == Backend.GLOO:
4684+
opts.device = torch.device("cpu")
4685+
else:
4686+
# Use the current device set by the user. If user did not set any, this
4687+
# may use default device 0, causing issues like hang or all processes
4688+
# creating context on device 0.
4689+
opts.device = device
4690+
warnings.warn( # warn only once
4691+
"No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. "
4692+
)
46794693

4680-
group = group or _get_default_group()
46814694
work = group.barrier(opts=opts)
46824695

46834696
if async_op:

0 commit comments

Comments
 (0)
0