[2d] Ensure gradient clear out pending AsyncCollectiveTensor in FSDP Extension #116122
Changes from all commits: cf3f6e8, 4c54da3, 01879d5, b9f933d, 7d8e388, bba67dd, 9760f30
Changes to the FSDP extension helpers (one hunk):

```diff
@@ -1,24 +1,34 @@
 from typing import Optional, Tuple
 
 import torch
-from torch.distributed._tensor import DTensor as DistributedTensor
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+def sync_grad_hook(grad):
+    if isinstance(grad, AsyncCollectiveTensor):
+        grad.wait()
+    return grad
+
+
```

Inline review thread on the `grad.wait()` line:

- Reviewer: @wanchaol Why do we discard the return value of `grad.wait()` here?
- @wanchaol: Good catch! I think we should unwrap it, yes, it's a bug. Could you help submit a fix for this?
- Reviewer: Oh wait, this looks like it was fixed on main already.
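The thread points at a subtle detail: `AsyncCollectiveTensor.wait()` returns the synchronized tensor rather than mutating the wrapper in place, so as written the hook appears to hand the wrapper back to autograd. Below is a minimal sketch of the unwrapped variant the reviewers discuss; it reflects their suggestion, not necessarily the exact fix that landed on main.

```python
from torch.distributed._functional_collectives import AsyncCollectiveTensor


def sync_grad_hook(grad):
    # Keep the tensor returned by wait() instead of discarding it, so the
    # gradient seen downstream is a plain, already-synchronized tensor rather
    # than the AsyncCollectiveTensor wrapper.
    if isinstance(grad, AsyncCollectiveTensor):
        grad = grad.wait()
    return grad
```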
The rest of the hunk updates `_flatten_tensor` and `_unflatten_tensor`:

```diff
 def _flatten_tensor(
     tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
-    if isinstance(tensor, DistributedTensor):
+    if isinstance(tensor, DTensor):
         tensor._local_tensor.requires_grad_()
         return tensor._local_tensor, tensor._spec
     return tensor, None
 
 
 def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
-    result = DistributedTensor.from_local(
+    # unflatten would mainly be called everytime FSDP allgather parameters.
+    result = DTensor.from_local(
         tensor,
         spec.mesh,
         spec.placements,
         run_check=False,
     )
+    if tensor.requires_grad:
+        tensor.register_hook(sync_grad_hook)
     return result
```
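For context, here is a rough usage sketch of how these helpers fit together: FSDP flattens a DTensor parameter down to its local shard plus the spec, and on all-gather re-wraps the shard as a DTensor while registering `sync_grad_hook` so gradients arriving as pending AsyncCollectiveTensors are waited on during backward. Everything in this sketch (the `demo` wrapper, mesh setup, tensor shapes) is an illustrative assumption, not part of this PR; it presumes an initialized process group and that the helpers from the diff above are in scope.

```python
# Illustrative sketch only, not part of this PR: assumes an initialized
# process group (e.g. gloo) and that _flatten_tensor/_unflatten_tensor
# from the diff above are importable in the current scope.
import torch
from torch.distributed._tensor import DeviceMesh, DTensor, Shard


def demo(world_size: int) -> None:
    mesh = DeviceMesh("cpu", list(range(world_size)))  # 1-D mesh (assumption)
    param = DTensor.from_local(
        torch.randn(4, 8, requires_grad=True), mesh, [Shard(0)]
    )

    # FSDP keeps only the local shard plus the spec needed to rebuild it.
    local_shard, spec = _flatten_tensor(param)

    # After all-gather, the shard is re-wrapped as a DTensor; sync_grad_hook is
    # registered on the local tensor so that gradients arriving as pending
    # AsyncCollectiveTensors are waited on during backward.
    rebuilt = _unflatten_tensor(local_shard, spec)
    rebuilt.to_local().sum().backward()  # triggers the registered hook
```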