[PT][FSDP] support custom all reduce hook across FSDP units #147114
This pull request was exported from Phabricator. Differential Revision: D68255583
@xunnanxu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge -r main
@pytorchbot started a rebase job onto refs/remotes/origin/main.
Successfully rebased.
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR.
@xunnanxu has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge
Merge failed. Reason: This PR has internal changes and must be landed via Phabricator! Please try reimporting/re-exporting the PR.
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged) |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This change adds an API `set_all_reduce_hook` to the `FSDPModule` to support customized all reduce, either in a native HSDP (2D mesh) setup or in a custom HSDP setup (1D FSDP + custom AR across replicas):

* For native HSDP, the original AR would still run as is, and this hook allows for additional gradient modification post all reduce.
* For custom HSDP, the original AR will be skipped, and all the logic is instead expected to be executed in the hook.

The custom hook is expected to perform its operations in place (no return value).

Example basic usage:

```
model = ...
fully_shard(model, mesh=...)
model.set_all_reduce_hook(my_hook)
```

By default, the hook will run in the default all reduce stream, post reduce scatter. When native HSDP is NOT enabled, the custom hook can be specified to run in a custom stream. This custom stream will also be synchronized post reduce scatter similarly. See the tests for examples.

Test Plan: CI

Reviewed By: awgu, ckluk2

Differential Revision: D68255583

Pull Request resolved: #147114

Approved by: https://github.com/awgu

cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o
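As an illustration of what such a hook can look like, below is a minimal sketch for the custom HSDP case, assuming the hook receives the reduce-scattered gradient tensor and must modify it in place. The process group `replica_group`, the hook body, and the exact `stream=` keyword are illustrative assumptions based on the description above, not verbatim from this PR.

```python
# Hypothetical sketch: a custom all-reduce hook for custom HSDP
# (1D FSDP + custom AR across replicas).
import torch
import torch.distributed as dist

replica_group = ...  # assumed: a process group spanning the replicas

def my_hook(grad: torch.Tensor) -> None:
    # In custom HSDP the native AR is skipped, so the hook performs the
    # cross-replica reduction itself, in place, with no return value.
    dist.all_reduce(grad, group=replica_group)
    grad.div_(dist.get_world_size(group=replica_group))

model = ...  # a module already wrapped with fully_shard(...)
model.set_all_reduce_hook(my_hook)

# When native HSDP is NOT enabled, the hook can be given its own stream,
# which is synchronized post reduce scatter like the default one:
custom_stream = torch.cuda.Stream()
model.set_all_reduce_hook(my_hook, stream=custom_stream)
```

Note that in the native HSDP case a hook like this would not replace the built-in all reduce; per the description, it would only apply additional in-place gradient modification after the native AR runs.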