[c10d] Fix extra CUDA context created by barrier by kwen2501 · Pull Request #149144 · pytorch/pytorch · GitHub

[c10d] Fix extra CUDA context created by barrier #149144

Closed
kwen2501 wants to merge 3 commits

Conversation

@kwen2501 (Contributor) commented Mar 13, 2025

Stack from ghstack (oldest at bottom):

Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses the `device_id` given by the
user when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.
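
A minimal usage sketch of what this fix targets (not part of the PR diff; the launcher environment and `LOCAL_RANK` handling are assumptions): binding the process group to a device via `device_id` lets c10d create its dummy dispatch tensor on the right GPU instead of spawning an extra context on cuda:0.

```python
import os
import torch
import torch.distributed as dist

# Assumes a torchrun-style launcher that sets RANK/WORLD_SIZE/LOCAL_RANK/
# MASTER_ADDR/MASTER_PORT in the environment.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
torch.cuda.set_device(local_rank)

dist.init_process_group(
    backend="nccl",
    device_id=torch.device("cuda", local_rank),  # device index c10d will use for dispatch
)
dist.barrier()  # dispatches on cuda:<local_rank>; no extra context on cuda:0
dist.destroy_process_group()
```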

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

pytorch-bot bot commented Mar 13, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149144

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit 731b4cd with merge base e9e1aac:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the oncall: distributed (distributed oncall triage queue) and release notes: distributed (c10d) labels Mar 13, 2025
kwen2501 added a commit that referenced this pull request Mar 13, 2025
Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

ghstack-source-id: 258020d
Pull Request resolved: #149144
kwen2501 requested review from H-Huang, wconstab and fduwjj on March 13, 2025 18:39
kwen2501 added the topic: bug fixes label Mar 13, 2025
@XilunWu (Contributor) left a comment:

LGTM, thx for the fix!

elif group.bound_device_id is not None:
# Use device id from `init_process_group(device_id=...)`
opts.device = group.bound_device_id
elif device.type == "cpu" or get_backend(group) == Backend.GLOO:
Member:

Is there a way to avoid depending on specific backend names/types? This makes it hard to add new ones that are compatible with core PT -- I've been trying to clean these up for torchft

@kwen2501 (Contributor, Author):

Yeah, I hope there is a way. The specific code is for a case where the user is on a GPU machine but only wants to use the CPU to do some stuff...
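
For illustration, a minimal sketch of that CPU-only scenario (the backend choice and the sanity check are assumptions, not code from this PR): with a Gloo group, the barrier resolves to the CPU device and should not initialize CUDA as a side effect.

```python
import torch
import torch.distributed as dist

# Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set by the launcher.
dist.init_process_group(backend="gloo")  # CPU-only collectives on a GPU machine
dist.barrier()                           # resolves to torch.device("cpu")
# Sanity check (assumes nothing else in this process has touched CUDA).
assert not torch.cuda.is_initialized()
dist.destroy_process_group()
```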

@@ -4654,30 +4654,43 @@ def barrier(
group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
async_op (bool, optional): Whether this op should be an async op
device_ids ([int], optional): List of device/GPU ids.
device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
Member:

Can we change this to

Only the first ID is used.

@kwen2501 (Contributor, Author):

I do mean only one is expected, because now we are expecting one device per thread. Some of the API signatures came from the old days.
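
A short sketch of the expected call pattern (illustrative; the `LOCAL_RANK` handling is an assumption about the launcher): exactly one ID, typically the local rank of the calling process.

```python
import os
import torch.distributed as dist

# Assumes the default process group has already been initialized.
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
dist.barrier(device_ids=[local_rank])  # a one-element list; one device per process
```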

# Use device id from `init_process_group(device_id=...)`
opts.device = group.bound_device_id
elif device.type == "cpu" or get_backend(group) == Backend.GLOO:
opts.device = torch.device("cpu")
Member:

Will Gloo fail if it's not a CPU device?

Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 13, 2025
Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

ghstack-source-id: 428f13a
Pull Request resolved: #149144
@kwen2501 (Contributor, Author):

@pytorchbot merge

pytorch-bot added the ciflow/trunk label (trigger trunk jobs on your pull request) Mar 17, 2025
@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: 1 jobs have failed, first few of them are: linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test

Details for Dev Infra team: raised by workflow job.

@kwen2501 (Contributor, Author):

The failure seems to be an issue with the CI instance and is unrelated.
@pytorchbot merge -i

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 1 checks: linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 070865253dc02320a0c6bf7395e44556a6f4998c returned non-zero exit code 1

Auto-merging test/distributed/test_c10d_nccl.py
Auto-merging torch/distributed/distributed_c10d.py
CONFLICT (content): Merge conflict in torch/distributed/distributed_c10d.py
error: could not apply 070865253dc... [c10d] Fix extra CUDA context created by barrier
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

@cyyever (Collaborator) commented Mar 29, 2025:

@pytorchbot merge -f "Unrelated failures"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator):

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch cherry-pick -x 070865253dc02320a0c6bf7395e44556a6f4998c returned non-zero exit code 1

Auto-merging test/distributed/test_c10d_nccl.py
Auto-merging torch/distributed/distributed_c10d.py
CONFLICT (content): Merge conflict in torch/distributed/distributed_c10d.py
error: could not apply 070865253dc... [c10d] Fix extra CUDA context created by barrier
hint: After resolving the conflicts, mark them with
hint: "git add/rm <pathspec>", then run
hint: "git cherry-pick --continue".
hint: You can instead skip this commit with "git cherry-pick --skip".
hint: To abort and get back to the state before "git cherry-pick",
hint: run "git cherry-pick --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Details for Dev Infra team: raised by workflow job.

Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request May 2, 2025
Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

ghstack-source-id: 96c32b9
Pull Request resolved: #149144
@kwen2501 (Contributor, Author) commented May 3, 2025:

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

kwen2501 added a commit that referenced this pull request May 5, 2025
Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

ghstack-source-id: 96c32b9
Pull Request resolved: #149144
@huydhn (Contributor) commented May 5, 2025:

@pytorchbot revert -m 'Internal failure looks legit' -c ghfirst

E           Traceback of where the remote function was issued on controller (most recent call last):
E             <not related to a specific invocation>
E           Traceback of where the remote function failed on worker (most recent call last):
E             File "<unknown>", line None, in remote function failed: Traceback (most recent call last):
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/monarch/worker/_testing_function.py", line 101, in barrier
E               dist.barrier(group=group, async_op=False, device_ids=device_ids)
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/torch/distributed/c10d_logger.py", line 81, in wrapper
E               return func(*args, **kwargs)
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/torch/distributed/distributed_c10d.py", line 4770, in barrier
E               work = group.barrier(opts=opts)
E           <class 'RuntimeError'>: CUDA error: invalid device ordinal
E           CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E           For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E           Device-side assertion tracking was not enabled by user.
E           RuntimeError: remote function failed: Traceback (most recent call last):
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/monarch/worker/_testing_function.py", line 101, in barrier
E               dist.barrier(group=group, async_op=False, device_ids=device_ids)
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/torch/distributed/c10d_logger.py", line 81, in wrapper
E               return func(*args, **kwargs)
E             File "/re_cwd/buck-out/v2/gen/fbcode/d6a07959eaa8ed59/monarch/python/tests/__test_remote_functions__/test_remote_functions#link-tree/torch/distributed/distributed_c10d.py", line 4770, in barrier
E               work = group.barrier(opts=opts)
E           <class 'RuntimeError'>: CUDA error: invalid device ordinal
E           CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
E           For debugging consider passing CUDA_LAUNCH_BLOCKING=1
E           Device-side assertion tracking was not enabled by user.
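
For reference, a hypothetical repro sketch of this failure mode (not the internal test itself): the error can surface when `device_ids` names an ordinal the calling process cannot see, e.g. a global ID used under a restrictive CUDA_VISIBLE_DEVICES.

```python
import torch
import torch.distributed as dist

# Hypothetical sketch; rank/world size and rendezvous come from the environment.
dist.init_process_group(backend="nccl")
bad_id = torch.cuda.device_count()  # one past the last visible device ordinal
try:
    dist.barrier(device_ids=[bad_id])  # dummy-tensor dispatch targets an invalid device
except RuntimeError as err:
    print(f"barrier failed: {err}")    # e.g. "CUDA error: invalid device ordinal"
dist.destroy_process_group()
```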

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request May 5, 2025
This reverts commit 457fa82.

Reverted #149144 on behalf of https://github.com/huydhn due to: Internal failure looks legit (comment).
@pytorchmergebot (Collaborator):

@kwen2501 your PR has been successfully reverted.

pytorchmergebot added the Reverted and ci-no-td (do not run TD on this PR) labels May 5, 2025
@kwen2501 (Contributor, Author) commented May 6, 2025:

@pytorchbot merge -f "Internal test was wrong; OSS version of barrier tests passed"

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

github-merge-queue bot pushed a commit to intel/torch-xpu-ops that referenced this pull request May 9, 2025
Refer to pytorch/pytorch#149144. Currently, `dist.barrier` accepts
`device_ids` as a parameter that doesn't have to be a list. When
`device_ids` is not provided or another value is passed, `barrier` will
use the device associated with the process group at initialization to
perform the synchronization.
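
As a backend-agnostic illustration (the accelerator query and the backend-name mapping below are assumptions, not taken from this commit), the same pattern can bind the group to whatever accelerator is present and let `barrier` fall back to that bound device when `device_ids` is omitted:

```python
import torch
import torch.distributed as dist

# torch.accelerator.current_accelerator() is the public wrapper around the
# torch._C._get_accelerator() call mentioned in the PR description.
acc = torch.accelerator.current_accelerator()  # e.g. cuda / xpu, or None
if acc is not None:
    device = torch.device(acc.type, 0)
    backend = {"cuda": "nccl", "xpu": "xccl"}.get(acc.type, "gloo")  # illustrative mapping
    dist.init_process_group(backend=backend, device_id=device)
else:
    dist.init_process_group(backend="gloo")
dist.barrier()  # no device_ids: uses the device bound at init, or the CPU
dist.destroy_process_group()
```
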
atalman pushed a commit that referenced this pull request May 27, 2025
Fixes #149119.

In ProcessGroup.hpp, we create a dummy tensor for dispatching. This
requires a correct device index. This PR uses `device_id` given by user
when calling `init_process_group`.

This PR also uses `torch._C._get_accelerator()` to determine the device
type.

ghstack-source-id: 96c32b9
Pull Request resolved: #149144
Labels: ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, oncall: distributed (Add this issue/PR to distributed oncall triage queue), release notes: distributed (c10d) (release notes category), Reverted, topic: bug fixes (topic category)
7 participants