[2d] add gradient hook to FSDP extension · pytorch/pytorch@ce3bc85 · GitHub

Commit ce3bc85

[2d] add gradient hook to FSDP extension
ghstack-source-id: 17a723a
Pull Request resolved: #116122
1 parent cbc70e9 commit ce3bc85

File tree: 1 file changed (+22 -3 lines)

Lines changed: 22 additions & 3 deletions
@@ -1,24 +1,43 @@
+from functools import partial
 from typing import Optional, Tuple
 
 import torch
-from torch.distributed._tensor import DTensor as DistributedTensor
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+def grad_layout_hook(param_placements, grad):
+    # a gradient hook to ensure the gradient layout is the same as
+    # the parameter layout, this is due to the fact that our current
+    # FSDP have implicit assumption that param/grad sharding layout
+    # should be the same after backward. However this is not always
+    # the case for DTensor, i.e. we might have a replicated param
+    # and a partial gradient and DTensor was relying on optimizer
+    # who really consumes the gradient to convert the layout.
+    if isinstance(grad, DTensor) and grad.placements != param_placements:
+        grad = grad.redistribute(grad.device_mesh, param_placements)
+    return grad
+
+
 def _flatten_tensor(
     tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
-    if isinstance(tensor, DistributedTensor):
+    if isinstance(tensor, DTensor):
         tensor._local_tensor.requires_grad_()
         return tensor._local_tensor, tensor._spec
     return tensor, None
 
 
+@torch._dynamo.disable
 def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
-    result = DistributedTensor.from_local(
+    # unflatten would mainly be called everytime FSDP allgather parameters.
+    result = DTensor.from_local(
         tensor,
         spec.mesh,
         spec.placements,
         run_check=False,
     )
+    if result.requires_grad:
+        # only register the hook if the tensor requires grad
+        result.register_hook(partial(grad_layout_hook, spec.placements))
     return result
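
To make the effect of the new grad_layout_hook concrete, here is a minimal sketch (not part of the commit) that runs the same check-and-redistribute step inline. It assumes a two-rank CPU process group launched with torchrun, the gloo backend, and an illustrative Shard(0) gradient layout; the commit's comment talks about partial gradients, but the redistribute mechanics are the same.

# Minimal sketch of the layout fix grad_layout_hook performs (illustrative,
# not from the commit): a gradient DTensor whose placements differ from the
# parameter's placements gets redistributed back to the parameter layout.
# Launch with: torchrun --nproc-per-node=2 demo.py
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

dist.init_process_group("gloo")  # backend/world size are assumptions for a CPU demo
mesh = DeviceMesh("cpu", list(range(dist.get_world_size())))

# The parameter is replicated, so FSDP expects a replicated gradient as well.
param_placements = (Replicate(),)

# Pretend backward produced a gradient sharded on dim 0 instead.
grad = DTensor.from_local(torch.randn(4, 8), mesh, [Shard(0)], run_check=False)

# The same check/redistribute that grad_layout_hook applies when registered
# via result.register_hook(partial(grad_layout_hook, spec.placements)).
if isinstance(grad, DTensor) and tuple(grad.placements) != param_placements:
    grad = grad.redistribute(mesh, param_placements)

assert tuple(grad.placements) == param_placements  # layout now matches the param
dist.destroy_process_group()

In the extension itself this logic runs lazily as a gradient hook registered in _unflatten_tensor, and only when the rebuilt DTensor has requires_grad set, so tensors that never produce gradients skip the extra redistribute.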
