[c10d] Fix extra CUDA context created by barrier · pytorch/pytorch@99138ee · GitHub

Commit 99138ee
[c10d] Fix extra CUDA context created by barrier
Fixes #149119. In ProcessGroup.hpp, we create a dummy tensor for dispatching, which requires a correct device index. This PR uses the `device_id` given by the user when calling `init_process_group`, and uses `torch._C._get_accelerator()` to determine the device type.

ghstack-source-id: 96c32b9
Pull Request resolved: #149144
1 parent 1341794 commit 99138ee
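For context, here is a minimal usage sketch (not part of the commit) of the pattern this fix targets: binding the process group to a device via `device_id` so that `barrier` dispatches on the right GPU instead of implicitly creating a context on `cuda:0`. The `torchrun`-style launch environment (`RANK` set, one process per GPU) is an assumption of the sketch.

```python
import os

import torch
import torch.distributed as dist

# Assumed launch: torchrun with one process per GPU, so RANK is set.
rank = int(os.environ["RANK"])
torch.cuda.set_device(rank)

# Binding the group to a device populates `bound_device_id`, which the
# fixed barrier() uses to pick the dummy tensor's device index.
dist.init_process_group(
    backend="nccl",
    device_id=torch.device("cuda", rank),
)

dist.barrier()  # dispatches on cuda:<rank>; no extra context on cuda:0
dist.destroy_process_group()
```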

File tree

2 files changed: +23 −21 lines


test/distributed/test_c10d_nccl.py

Lines changed: 0 additions & 11 deletions
@@ -3505,17 +3505,6 @@ def test_nccl_barrier_device_ids(self):
 
         c10d.barrier(device_ids=[self.rank])
 
-    @requires_nccl()
-    @skip_if_lt_x_gpu(2)
-    def test_nccl_barrier_device_ids_function_argument(self):
-        store = c10d.FileStore(self.file_name, self.world_size)
-        c10d.init_process_group(
-            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
-        )
-
-        with self.assertRaisesRegex(TypeError, "Invalid function argument"):
-            c10d.barrier(device_ids=self.rank)
-
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_unwaited(self) -> None:

torch/distributed/distributed_c10d.py

Lines changed: 23 additions & 10 deletions
@@ -4596,29 +4596,42 @@ def barrier(
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
-        device_ids ([int], optional): List of device/GPU ids.
+        device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
 
     Returns:
         Async work handle, if async_op is set to True.
         None, if not async_op or if not part of the group
 
     .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
     """
+    group = group or _get_default_group()
+
     if _rank_not_in_group(group):
         _warn_not_in_group("barrier")
         return
 
     opts = BarrierOptions()
-    opts.device = torch.device(_get_object_coll_device(group))
-    if device_ids is not None:
-        if isinstance(device_ids, list):
-            opts.device_ids = device_ids
-        else:
-            raise TypeError(
-                "Invalid function argument: device_ids type should be List[int]"
-            )
+    # Detect the accelerator on the machine. If no accelerator is available, it
+    # returns CPU.
+    device = torch._C._get_accelerator()
+    if isinstance(device_ids, list):
+        opts.device_ids = device_ids
+        # use only the first device id
+        opts.device = torch.device(device.type, device_ids[0])
+    elif getattr(group, "bound_device_id", None) is not None:
+        # Use device id from `init_process_group(device_id=...)`
+        opts.device = group.bound_device_id  # type: ignore[assignment]
+    elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
+        opts.device = torch.device("cpu")
+    else:
+        # Use the current device set by the user. If user did not set any, this
+        # may use default device 0, causing issues like hang or all processes
+        # creating context on device 0.
+        opts.device = device
+        warnings.warn(  # warn only once
+            "No device id is provided via `init_process_group` or `barrier`. Using the current device set by the user."
+        )
 
-    group = group or _get_default_group()
     work = group.barrier(opts=opts)
 
     if async_op:
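Summarizing the new resolution order in `barrier`: an explicit `device_ids` list wins, then the group's `bound_device_id` from `init_process_group(device_id=...)`, then CPU when no accelerator (or a CPU-only backend) is present, and only as a last resort the current accelerator device, with a warning. A short sketch of the explicit path, assuming a group initialized as in the example above:

```python
# Passing the rank's GPU index explicitly also avoids the warning. Note that
# the old TypeError for a non-list `device_ids` is gone, which is why the
# test deleted above no longer applies.
dist.barrier(device_ids=[rank])
```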
