[2d] Ensure gradient clear out pending AsyncCollectiveTensor in FSDP Extension by wanchaol · Pull Request #116122 · pytorch/pytorch


Closed · wants to merge 7 commits

Conversation

wanchaol (Collaborator) commented Dec 19, 2023

Stack from ghstack (oldest at bottom):

As titled, this PR adds a gradient hook to the FSDP DTensor extension that checks whether any gradients are AsyncCollectiveTensors; if so, we eagerly wait on them.

This is needed because a parameter's gradient might still be a pending AsyncCollectiveTensor; if we feed it directly to FSDP, FSDP would use the ACT's storage for the reduce_scatter, which is wrong.
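
For illustration, a minimal sketch of the kind of hook described above (not the PR's exact code; the hook name and the commented-out registration are illustrative):

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def _wait_on_pending_grad(grad: torch.Tensor) -> torch.Tensor:
    # If the backward left a pending async collective inside the gradient,
    # block on it so downstream consumers (e.g. FSDP's reduce_scatter) see a
    # fully materialized tensor instead of the ACT wrapper's storage.
    if isinstance(grad, AsyncCollectiveTensor):
        grad = grad.wait()
    return grad


# Illustrative registration on a parameter; a hook that returns a tensor
# replaces the gradient that autograd accumulates into param.grad.
# param.register_hook(_wait_on_pending_grad)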

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225

pytorch-bot bot added the release notes: distributed (fsdp) label Dec 19, 2023
pytorch-bot bot commented Dec 19, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116122

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 9760f30 with merge base 0e63837:

FLAKY - The following job failed but was likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

github-actions bot added the oncall: distributed label Dec 19, 2023
@wanchaol wanchaol changed the title [2d] add gradient hook to FSDP extension [2d] Ensure param/grad sharding layout consistency in FSDP extension Dec 19, 2023
wanchaol added a commit that referenced this pull request Dec 19, 2023
ghstack-source-id: 17a723a
Pull Request resolved: #116122
wanchaol added a commit that referenced this pull request Dec 21, 2023
As titled, this PR adds a gradient hook to the FSDP DTensor extension. Since FlatParam FSDP assumes the same sharding layout after the backward pass, we also ensure the DTensor gradients follow that sharding layout.

ghstack-source-id: c2123b7
Pull Request resolved: #116122
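
For reference, a rough sketch of the layout-consistency idea described in that commit message, using a hypothetical helper (not the PR's code):

# Hypothetical helper: if a DTensor gradient's placements differ from its
# parameter's, redistribute it so FSDP sees a consistent sharding layout.
def _match_param_layout(param_dt, grad_dt):
    if grad_dt.placements != param_dt.placements:
        grad_dt = grad_dt.redistribute(param_dt.device_mesh, param_dt.placements)
    return grad_dt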
@wanchaol wanchaol requested review from awgu and wz337 December 21, 2023 16:58
awgu (Collaborator) left a comment

LGTM!

"mlp.net2": RowwiseParallel(output_layouts=Shard(0)),
},
)
distribute_rmsnorm(tp_model.mlp_norm, mesh_1d)
Collaborator

IIUC, this means the RMSNorm runs as "sequence parallel"?

Collaborator Author

Yep, this is sequence-parallel LayerNorm/RMSNorm: the backward produces a Partial DTensor for the gradient, while the parameter is a replicated DTensor.
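
For illustration, a minimal sketch of the layouts being described, assuming a 1-D mesh, one CUDA device per rank, and an initialized process group (the tensor sizes and mesh construction here are made up):

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor

# Hypothetical 1-D mesh over all ranks (assumes the default process group exists).
mesh_1d = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

# Under sequence parallelism the norm weight is a replicated DTensor...
norm_weight = distribute_tensor(torch.ones(256), mesh_1d, placements=[Replicate()])

# ...while its backward yields a Partial DTensor: each rank holds a partial sum
# that must be all-reduced before it matches the parameter's Replicate layout.
# That (possibly async) all-reduce is what can surface as an
# AsyncCollectiveTensor in .grad, which the hook in this PR waits on.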

fduwjj (Contributor) left a comment

LGTM

@wanchaol wanchaol changed the title [2d] Ensure param/grad sharding layout consistency in FSDP extension [2d] Ensure gradient clear out pending AsyncCollectiveTensor in FSDP Extension Jan 2, 2024
wanchaol (Collaborator, Author) commented Jan 2, 2024

Changed the purpose of this PR a bit to make the gradient hook eagerly wait for the ACT.

pytorchmergebot pushed a commit that referenced this pull request Jan 2, 2024
@facebook-github-bot facebook-github-bot deleted the gh/wanchaol/417/head branch January 6, 2024 15:22
@atalman atalman added this to the 2.2.1 milestone Jan 16, 2024
from torch.distributed._tensor.placement_types import DTensorSpec


def sync_grad_hook(grad):
    if isinstance(grad, AsyncCollectiveTensor):
        grad.wait()
Collaborator

@wanchaol Why do we discard the return value of grad.wait(), which unwraps the ACT? I.e., why not grad = grad.wait()?

Collaborator Author

Good catch! Yes, I think we should unwrap; it's a bug. Could you help submit a fix for this?

Collaborator

Oh wait, this looks like it was fixed on main already.
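
For reference, the fix being discussed amounts to keeping the unwrapped result of wait(); a minimal sketch of what that looks like (not the exact patch that landed on main):

from torch.distributed._functional_collectives import AsyncCollectiveTensor


def sync_grad_hook(grad):
    if isinstance(grad, AsyncCollectiveTensor):
        # wait() blocks on the pending collective and returns the unwrapped
        # tensor; keep the result instead of discarding it.
        grad = grad.wait()
    return grad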

Skylion007 pushed a commit to Skylion007/pytorch that referenced this pull request Feb 12, 2024
…Extension (pytorch#116122)

Pull Request resolved: pytorch#116122
Approved by: https://github.com/awgu, https://github.com/fduwjj
atalman pushed a commit that referenced this pull request Feb 14, 2024
Co-authored-by: Wanchao Liang <wanchaol@users.noreply.github.com>
resolved: #116122
resolved: #117020
fixes #117126
resolved: #117336
Labels
Merged · oncall: distributed · release notes: distributed (fsdp)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants