+from functools import partial
 from typing import Optional, Tuple
 
 import torch
-from torch.distributed._tensor import DTensor as DistributedTensor
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+def grad_layout_hook(param_placements, grad):
+    # A gradient hook to ensure the gradient layout is the same as the
+    # parameter layout. This is needed because our current FSDP has an
+    # implicit assumption that the param/grad sharding layout stays the
+    # same after backward. However, this is not always the case for
+    # DTensor: we might have a replicated param and a partial gradient,
+    # and DTensor relied on the optimizer that actually consumes the
+    # gradient to convert the layout.
+    if isinstance(grad, DTensor) and grad.placements != param_placements:
+        grad = grad.redistribute(grad.device_mesh, param_placements)
+    return grad
+
+
 def _flatten_tensor(
     tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
-    if isinstance(tensor, DistributedTensor):
+    if isinstance(tensor, DTensor):
         tensor._local_tensor.requires_grad_()
         return tensor._local_tensor, tensor._spec
     return tensor, None
 
 
+@torch._dynamo.disable
 def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
-    result = DistributedTensor.from_local(
+    # Unflatten is mainly called every time FSDP all-gathers the parameters.
+    result = DTensor.from_local(
         tensor,
         spec.mesh,
         spec.placements,
         run_check=False,
     )
+    if result.requires_grad:
+        # Only register the hook if the rebuilt tensor requires grad.
+        result.register_hook(partial(grad_layout_hook, spec.placements))
     return result
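
For context, the sketch below shows how these helpers round-trip a DTensor and where the gradient-layout hook gets attached. It is only an illustration, not part of the change: it assumes a hypothetical single-rank "gloo" process group and a one-device CPU mesh, whereas in practice FSDP calls `_flatten_tensor` when it shards parameters and `_unflatten_tensor` after each all-gather.

```python
# Minimal round-trip sketch (illustrative only): a single-rank "gloo" group
# and a one-device CPU mesh stand in for a real FSDP setup.
import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor
from torch.distributed._tensor.placement_types import Replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
param = DTensor.from_local(
    torch.randn(4, 4, requires_grad=True), mesh, [Replicate()], run_check=False
)

# Flatten: keep the local shard plus the DTensorSpec needed to rebuild it.
local_tensor, spec = _flatten_tensor(param)
assert spec is not None and spec.placements == param.placements

# Unflatten: rebuild the DTensor (as after an all-gather); because the
# result requires grad, grad_layout_hook is registered on it here.
rebuilt = _unflatten_tensor(local_tensor, spec)
assert isinstance(rebuilt, DTensor)
assert rebuilt.placements == param.placements

dist.destroy_process_group()
```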