Fix support for nccl < 2.17 by oraluben · Pull Request #145719 · pytorch/pytorch · GitHub

Fix support for nccl < 2.17 #145719


Open · oraluben wants to merge 22 commits into main

Conversation

@oraluben (Contributor) commented Jan 27, 2025

Fix the build failure with older (< 2.17) NCCL.

Refactor the NCCL-version-related code:

  1. Fix the failure against old NCCL versions introduced by [PGNCCL] Use non-blocking mode by default in eager init #138527;
  2. Remove checks that have become dead code because those NCCL versions are no longer supported (since there's a static assert checking NCCL >= 2.7: [rfc][be] static assert that nccl version is >= 2.7 #142023);
  3. Move the NCCL macros from various places to torch/csrc/cuda/nccl.h and unify some of the style (#if to #ifdef), which I hope improves the maintainability of the NCCL code; see the sketch after this description.

Resolves #141914

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
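
As a rough sketch of the guard style in item 3 (the flag name and layout below are illustrative, not necessarily what torch/csrc/cuda/nccl.h actually contains): feature availability is derived from the NCCL version macros once, in one header, and the rest of the code only tests the resulting flag with #ifdef.

// Illustrative sketch only -- not the actual torch/csrc/cuda/nccl.h.
#include <nccl.h>  // provides NCCL_MAJOR / NCCL_MINOR

// Hypothetical flag: non-blocking communicator init needs NCCL >= 2.17.
#if defined(NCCL_MAJOR) && defined(NCCL_MINOR) && \
    ((NCCL_MAJOR > 2) || (NCCL_MAJOR == 2 && NCCL_MINOR >= 17))
#define NCCL_HAS_COMM_NONBLOCKING 1
#endif

void initComm() {
#ifdef NCCL_HAS_COMM_NONBLOCKING
  // Eager, non-blocking communicator initialization (NCCL >= 2.17).
#else
  // Blocking initialization fallback for older NCCL builds.
#endif
}

Keeping the version arithmetic in a single header means each call site only needs #ifdef on one flag, which is easier to audit than repeating the major/minor comparison everywhere.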

pytorch-bot bot added the oncall: distributed and release notes: distributed (c10d) labels on Jan 27, 2025
pytorch-bot bot commented Jan 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/145719

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit ebf3f48 with merge base 762724f:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@oraluben (Contributor, Author)

@pytorchbot label "topic: not user facing"

pytorch-bot bot added the topic: not user facing label on Jan 27, 2025
oraluben marked this pull request as ready for review on January 27, 2025 08:08
@wconstab (Contributor)

nit: "(since there's a static assert checking NCCL >= 2.7:"
I think you had a typo in your PR description; it should be >= 2.4, right?

@oraluben (Contributor, Author) commented Jan 28, 2025

it should be >= 2.4, right?

No, it's already 2.7 now:

static_assert(
(NCCL_MAJOR == 2 && NCCL_MINOR >= 7) || (NCCL_MAJOR > 2),
"NCCL version must be 2.7 or later");

I also got a bit confused when I saw 2.4 vs 2.7 in #141914; it looks like 2.4 is the typo?

Actually, I couldn't find an NCCL < 2.8 to validate whether 2.7 really works, and the same goes for 2.4 (2.8 is tested).

@c-p-i-o (Contributor) commented Jan 28, 2025

Oops, sorry about the confusion between 2.4 and 2.7.

Also, when I tried to simplify the code in #141914, I too ran into test timeouts.
So there's definitely something nefarious going on that needs to be looked at to get the failing tests to pass.

@oraluben (Contributor, Author)

Also, when I tried to simplify the code in #141914, I too ran into test timeouts.
So there's definitely something nefarious going on that needs to be looked at to get the failing tests to pass.

Do you plan to fix this? I can wait for your PR to be merged first, or I can try to resolve it myself; the failure seems to reproduce reliably on CUDA 11.8.

@oraluben (Contributor, Author)

@c-p-i-o I've verified that the update should fix the failure. The cause is that torch/csrc/distributed/c10d/NCCLUtils.hpp and torch/csrc/distributed/c10d/quantization/quantization_gpu.cu did not previously include the header that contains the checks.
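
For context, a minimal self-contained illustration of the failure mode described above (the macro name is hypothetical and this is not the actual PyTorch source): in the C preprocessor, an undefined identifier evaluates to 0 inside #if, so a file that forgets to include the header defining the check silently compiles the "unsupported" branch instead of producing an error.

// Illustration only. Suppose the guard lives in a shared header:
//   // nccl_version_guards.h (hypothetical)
//   #define HAS_NCCL_BF16 1   // defined only when NCCL is new enough
// and the file below uses the guard but never includes that header:

// #include "nccl_version_guards.h"   // <-- the missing include

#if HAS_NCCL_BF16        // undefined here, so the preprocessor treats it as 0
constexpr bool kBf16Enabled = true;
#else
constexpr bool kBf16Enabled = false;  // silently chosen on every build
#endif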

@c-p-i-o (Contributor) commented Jan 28, 2025

Do you plan to fix this? I can wait for your PR to be merged first, or I can try to resolve it myself; the failure seems to reproduce reliably on CUDA 11.8.

Go ahead and land your PR! I'll abandon mine - no problem!

@c-p-i-o (Contributor) commented Jan 29, 2025

The remaining test failure needs to be investigated:

distributed/algorithms/quantization/test_quantization.py::DistQuantizationTests::test_all_to_all_bfp16
Traceback (most recent call last):
  File "/var/lib/jenkins/workspace/test/distributed/algorithms/quantization/test_quantization.py", line 314, in <module>
    run_tests()
  File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 1266, in run_tests
    assert len(failed_tests) == 0, "{} unit test(s) failed:\n\t{}".format(
AssertionError: 1 unit test(s) failed:
    distributed/algorithms/quantization/test_quantization.py::DistQuantizationTests::test_all_to_all_bfp16

https://github.com/pytorch/pytorch/actions/runs/13013038548/job/36309457605?pr=145719

@oraluben (Contributor, Author) commented Jan 29, 2025

Looks like the test itself has a bug. The macro was not working as expected before, so bf16 support in NCCL was never actually enabled?

bfp16 uses _FloatToBfloat16Quantized, which uses fp16 on CPU and bf16 on GPU if supported. Previously, the bf16 support in NCCL was never enabled.
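
A rough sketch of the gating being discussed (the flag name is a placeholder, and the 2.10 cut-off is an assumption about when ncclBfloat16 became available, not something stated in this PR): with the guard header actually visible, the quantized collective can use ncclBfloat16 on new enough NCCL and fall back to half precision otherwise.

#include <nccl.h>

// Placeholder flag; assumes ncclBfloat16 is available from NCCL 2.10 onward.
#if (NCCL_MAJOR > 2) || (NCCL_MAJOR == 2 && NCCL_MINOR >= 10)
#define NCCL_HAS_BF16_DATATYPE 1
#endif

// Wire datatype used by the quantized all_to_all in this sketch.
inline ncclDataType_t quantizedWireDtype() {
#ifdef NCCL_HAS_BF16_DATATYPE
  return ncclBfloat16;  // bf16 path, only reachable if the guard is visible
#else
  return ncclHalf;      // fp16 fallback on older NCCL
#endif
}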

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-focal-rocm6.3-py3.10 / test (distributed, 1, 1, linux.rocm.gpu.4)

Details for Dev Infra team (raised by workflow job)

@oraluben (Contributor, Author) commented Feb 8, 2025

Ping, we need another approval to run lint here :)

@oraluben (Contributor, Author)

@pytorchbot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch merge --squash __pull-request-145719__init__ returned non-zero exit code 1

Auto-merging torch/csrc/distributed/c10d/NCCLUtils.cpp
Auto-merging torch/csrc/distributed/c10d/NCCLUtils.hpp
Auto-merging torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
Squash commit -- not updating HEAD
Automatic merge failed; fix conflicts and then commit the result.
Details for Dev Infra team (raised by workflow job)

@oraluben (Contributor, Author) commented Mar 1, 2025

@pytorchbot rebase

@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator)

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/145719/head returned non-zero exit code 1

Rebasing (1/15)
Auto-merging torch/csrc/distributed/c10d/NCCLUtils.cpp
Auto-merging torch/csrc/distributed/c10d/NCCLUtils.hpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/NCCLUtils.hpp
Auto-merging torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
CONFLICT (content): Merge conflict in torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp
error: could not apply 580a675a0b0... Support nccl >= 2.7
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 580a675a0b0... Support nccl >= 2.7

Raised by https://github.com/pytorch/pytorch/actions/runs/13602354396

pytorch-bot bot removed the ciflow/trunk label on Mar 1, 2025
@kwen2501 (Contributor)

Hi, just wondering if we still have the build issue for NCCL < 2.17?

@oraluben (Contributor, Author)

Hi, just wondering if we still have the build issue for NCCL < 2.17?

I didn't test on main, but it looks like we do.

@oraluben (Contributor, Author)

@pytorchbot merge

pytorch-bot bot added the ciflow/trunk label on Mar 11, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / cuda12.4-py3.10-gcc9-sm80 / build

Details for Dev Infra team (raised by workflow job)

github-actions bot

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

github-actions bot added the Stale label on May 10, 2025
Labels
ciflow/trunk · oncall: distributed · open source · release notes: distributed (c10d) · Stale · topic: not user facing