[c10d] Fix extra CUDA context created by barrier (#149144) · pytorch/pytorch@a8f727c · GitHub
[go: up one dir, main page]

Skip to content

Commit a8f727c

Browse files
kwen2501 authored and pytorchmergebot committed
[c10d] Fix extra CUDA context created by barrier (#149144)
Fixes #149119. In ProcessGroup.hpp, we create a dummy tensor for dispatching. This requires a correct device index. This PR uses `device_id` given by user when calling `init_process_group`. This PR also uses `torch._C._get_accelerator()` to determine the device type. Pull Request resolved: #149144 Approved by: https://github.com/XilunWu, https://github.com/fduwjj, https://github.com/cyyever
1 parent 12a8b70 commit a8f727c

File tree

2 files changed

+23
-21
lines changed

2 files changed

+23
-21
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3516,17 +3516,6 @@ def test_nccl_barrier_device_ids(self):
35163516

35173517
c10d.barrier(device_ids=[self.rank])
35183518

3519-
@requires_nccl()
3520-
@skip_if_lt_x_gpu(2)
3521-
def test_nccl_barrier_device_ids_function_argument(self):
3522-
store = c10d.FileStore(self.file_name, self.world_size)
3523-
c10d.init_process_group(
3524-
backend="nccl", rank=self.rank, world_size=self.world_size, store=store
3525-
)
3526-
3527-
with self.assertRaisesRegex(TypeError, "Invalid function argument"):
3528-
c10d.barrier(device_ids=self.rank)
3529-
35303519
@requires_nccl()
35313520
@skip_if_lt_x_gpu(2)
35323521
def test_unwaited(self) -> None:

torch/distributed/distributed_c10d.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4730,30 +4730,43 @@ def barrier(
47304730
group (ProcessGroup, optional): The process group to work on. If None,
47314731
the default process group will be used.
47324732
async_op (bool, optional): Whether this op should be an async op
4733-
device_ids ([int], optional): List of device/GPU ids.
4733+
device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
47344734
47354735
Returns:
47364736
Async work handle, if async_op is set to True.
47374737
None, if not async_op or if not part of the group
47384738
47394739
.. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
47404740
"""
4741+
group = group or _get_default_group()
4742+
47414743
if _rank_not_in_group(group):
47424744
_warn_not_in_group("barrier")
47434745
return
47444746

47454747
opts = BarrierOptions()
4746-
opts.device = torch.device(_get_object_coll_device(group))
47474748
opts.asyncOp = async_op
4748-
if device_ids is not None:
4749-
if isinstance(device_ids, list):
4750-
opts.device_ids = device_ids
4751-
else:
4752-
raise TypeError(
4753-
"Invalid function argument: device_ids type should be List[int]"
4754-
)
4749+
# Detect the accelerator on the machine. If no accelerator is available, it
4750+
# returns CPU.
4751+
device = torch._C._get_accelerator()
4752+
if isinstance(device_ids, list):
4753+
opts.device_ids = device_ids
4754+
# use only the first device id
4755+
opts.device = torch.device(device.type, device_ids[0])
4756+
elif getattr(group, "bound_device_id", None) is not None:
4757+
# Use device id from `init_process_group(device_id=...)`
4758+
opts.device = group.bound_device_id # type: ignore[assignment]
4759+
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
4760+
opts.device = torch.device("cpu")
4761+
else:
4762+
# Use the current device set by the user. If user did not set any, this
4763+
# may use default device 0, causing issues like hang or all processes
4764+
# creating context on device 0.
4765+
opts.device = device
4766+
warnings.warn( # warn only once
4767+
"No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. "
4768+
)
47554769

4756-
group = group or _get_default_group()
47574770
work = group.barrier(opts=opts)
47584771

47594772
if async_op:

0 commit comments

Comments
 (0)
0