[2d] Ensure gradient clear out pending AsyncCollectiveTensor in FSDP Extension #116122
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/116122. Note: links to docs will display an error until the docs builds have completed. ✅ As of commit 9760f30 with merge base 0e63837, you can merge normally (1 unrelated failure, flagged as FLAKY on trunk). This comment was automatically generated by Dr. CI and updates every 15 minutes.
… extension" cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
… extension" cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 awgu penguinwu fegin XilunWu fduwjj wz337 tianyu-l wconstab yf225 [ghstack-poisoned]
… extension" As titled, this PR adds gradient hook to the FSDP DTensor extension, given the fact that FlatParam FSDP assumes same sharding layout after backward pass, here we also ensure the DTensor gradients follow the same sharding layout [ghstack-poisoned]
LGTM!
"mlp.net2": RowwiseParallel(output_layouts=Shard(0)), | ||
}, | ||
) | ||
distribute_rmsnorm(tp_model.mlp_norm, mesh_1d) |
IIUC, this means the RMSNorm runs as "sequence parallel"?
Yep, this is sequence-parallel LayerNorm/RMSNorm: the backward pass would produce partial DTensor gradients, while the parameters would be replicated DTensors.
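For context, here is a minimal sketch of what a distribute_rmsnorm-style helper could look like (the helper name and exact setup are assumptions; the test's actual helper may differ): it replicates the norm's parameters as DTensors via distribute_module, so sequence-parallel activations are normalized locally and the resulting parameter gradients come back as partial DTensors.

# Hedged sketch, not the actual test helper: replicate a norm module's
# parameters as DTensors on the given device mesh.
import torch
from torch.distributed._tensor import Replicate, distribute_module, distribute_tensor


def distribute_rmsnorm(norm_module: torch.nn.Module, device_mesh) -> torch.nn.Module:
    def partition_fn(name, module, mesh):
        # Replace each local parameter with a replicated DTensor parameter.
        for p_name, param in list(module.named_parameters(recurse=False)):
            dtensor_param = torch.nn.Parameter(
                distribute_tensor(param, mesh, [Replicate()])
            )
            module.register_parameter(p_name, dtensor_param)

    return distribute_module(norm_module, device_mesh, partition_fn)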
LGTM
… extension" As titled, this PR adds gradient hook to the FSDP DTensor extension, given the fact that FlatParam FSDP assumes same sharding layout after backward pass, here we also ensure the DTensor gradients follow the same sharding layout [ghstack-poisoned]
… extension" As titled, this PR adds gradient hook to the FSDP DTensor extension, given the fact that FlatParam FSDP assumes same sharding layout after backward pass, here we also ensure the DTensor gradients follow the same sharding layout [ghstack-poisoned]
… extension" As titled, this PR adds gradient hook to the FSDP DTensor extension, given the fact that FlatParam FSDP assumes same sharding layout after backward pass, here we also ensure the DTensor gradients follow the same sharding layout [ghstack-poisoned]
Changed the purpose of this PR a bit: the gradient hook now eagerly waits for the ACT.
Pull Request resolved: #116244 Approved by: https://github.com/awgu, https://github.com/wz337, https://github.com/fduwjj ghstack dependencies: #116122
from torch.distributed._tensor.placement_types import DTensorSpec


def sync_grad_hook(grad):
    if isinstance(grad, AsyncCollectiveTensor):
        grad.wait()
@wanchaol Why do we discard the return value of grad.wait(), which unwraps the ACT? I.e., why not grad = grad.wait()?
Good catch! Yes, I think we should unwrap it; it's a bug. Could you help submit a fix for this?
Oh wait, this looks like it was fixed on main already.
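For reference, the corrected hook would wait on and unwrap the gradient; a minimal sketch of that fix (the exact code on main may differ; the AsyncCollectiveTensor import path is assumed to be torch.distributed._functional_collectives):

# Sketch of the fixed hook: wait on the pending collective and return the
# unwrapped gradient so downstream consumers never see the ACT wrapper.
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def sync_grad_hook(grad):
    if isinstance(grad, AsyncCollectiveTensor):
        grad = grad.wait()
    return grad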
…Extension (pytorch#116122) As titled, this PR adds a gradient hook to the FSDP DTensor extension to check whether any gradients are AsyncCollectiveTensors; if there are, we eagerly wait on them there. This is needed because a parameter's gradient might still be a pending AsyncCollectiveTensor; if we feed it directly to FSDP, FSDP would use the ACT's storage to do the reduce_scatter, which is wrong. Pull Request resolved: pytorch#116122 Approved by: https://github.com/awgu, https://github.com/fduwjj
Stack from ghstack (oldest at bottom):
As titled, this PR adds a gradient hook to the FSDP DTensor extension to check whether any gradients are AsyncCollectiveTensors; if there are, we eagerly wait on them there.
This is needed because a parameter's gradient might still be a pending AsyncCollectiveTensor; if we feed it directly to FSDP, FSDP would use the ACT's storage to do the reduce_scatter, which is wrong.
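To illustrate the idea, here is a hedged sketch (not the actual FSDP-extension code; install_grad_sync_hooks and _wait_async_grad are hypothetical names): register a backward hook on each parameter so that a gradient that is still a pending AsyncCollectiveTensor is waited on and unwrapped before FSDP reads its storage for the reduce_scatter.

# Hedged sketch (hypothetical helpers, not the actual extension code): ensure
# parameter gradients are materialized tensors, not pending
# AsyncCollectiveTensors, before FSDP consumes them.
import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def _wait_async_grad(grad: torch.Tensor) -> torch.Tensor:
    if isinstance(grad, AsyncCollectiveTensor):
        # Block until the pending collective finishes and return the real
        # tensor in place of the ACT wrapper.
        grad = grad.wait()
    return grad


def install_grad_sync_hooks(module: torch.nn.Module) -> None:
    # Register the hook on every trainable parameter so it runs when that
    # parameter's gradient is produced during backward.
    for param in module.parameters():
        if param.requires_grad:
            param.register_hook(_wait_async_grad)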
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @fduwjj @wz337 @tianyu-l @wconstab @yf225