
[2d] Ensure gradient clear out pending AsyncCollectiveTensor in FSDP Extension #116122


Closed
wanchaol wants to merge 7 commits into pytorch/pytorch
87 changes: 84 additions & 3 deletions test/distributed/fsdp/test_fsdp_tp_integration.py
@@ -8,7 +8,14 @@
from torch import distributed as dist
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._tensor import DeviceMesh, DTensor as DT, init_device_mesh, Shard
from torch.distributed._tensor import (
DeviceMesh,
distribute_module,
DTensor,
init_device_mesh,
Shard,
)
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
FullyShardedDataParallel as FSDP,
@@ -26,6 +33,7 @@
run_tests,
TEST_WITH_DEV_DBG_ASAN,
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule

if not dist.is_available():
print("Distributed not available, skipping tests", file=sys.stderr)
@@ -68,6 +76,34 @@ def get_non_sharded_param_names() -> List[str]:
return ["net3.weight", "net3.bias"]


# simple RMSNorm layer for testing
class RMSNormPython(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = torch.nn.Parameter(torch.ones(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
output = self._norm(x)
return output * self.weight


def distribute_rmsnorm(module, device_mesh):
def prepare_input_fn(inputs, device_mesh):
shard_tensor = DTensor.from_local(inputs[0], device_mesh, [Shard(0)])
return shard_tensor

def prepare_output_fn(outputs, device_mesh):
return outputs.to_local()

return distribute_module(
module, device_mesh, input_fn=prepare_input_fn, output_fn=prepare_output_fn
)


class TestTPFSDPIntegration(FSDPTest):
def _get_params_and_sharding_info(
self,
@@ -260,8 +296,8 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
sequence_parallelize_plan,
)
tp_pg = mesh_2d["tp"].get_group(mesh_dim=0)
assert isinstance(tp_fsdp_model.net1.weight, DT)
assert isinstance(tp_fsdp_model.net2.weight, DT)
assert isinstance(tp_fsdp_model.net1.weight, DTensor)
assert isinstance(tp_fsdp_model.net2.weight, DTensor)
tp_fsdp_model = FSDP(
tp_fsdp_model,
cpu_offload=cpu_offload,
@@ -314,6 +350,51 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
tp_fsdp_out = tp_fsdp_model(inp)
self.assertEqual(fsdp_out, tp_fsdp_out)

@skip_if_lt_x_gpu(4)
def test_fsdp_tp_gradient_layout(self):
"""
Tests TP + FSDP extension with consistent gradient layout
"""
mesh_2d = init_device_mesh(
"cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
)

class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = MLPModule("cuda")
self.mlp_norm = RMSNormPython(10)

def forward(self, x):
return self.mlp(self.mlp_norm(x))

model = TestModel().cuda(self.rank)

# Shard with TP and test gradient
tp_mesh = mesh_2d["tp"]
tp_model = parallelize_module(
model,
tp_mesh,
{
"mlp.net1": ColwiseParallel(input_layouts=Shard(0)),
"mlp.net2": RowwiseParallel(output_layouts=Shard(0)),
},
)
distribute_rmsnorm(tp_model.mlp_norm, tp_mesh)

fsdp_2d_model = FSDP(tp_model, device_mesh=mesh_2d["dp"])
comm_mode = CommDebugMode()

with comm_mode:
fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward()

funcol = torch.ops.c10d_functional
comm_counts = comm_mode.get_comm_counts()
self.assertEqual(comm_mode.get_total_counts(), 5)
self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2)
self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2)
self.assertEqual(comm_counts[funcol.all_reduce], 1)


instantiate_parametrized_tests(TestTPFSDPIntegration)

16 changes: 13 additions & 3 deletions torch/distributed/tensor/parallel/_data_parallel_utils.py
@@ -1,24 +1,34 @@
from typing import Optional, Tuple

import torch
from torch.distributed._tensor import DTensor as DistributedTensor
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed._tensor import DTensor
from torch.distributed._tensor.placement_types import DTensorSpec


def sync_grad_hook(grad):
if isinstance(grad, AsyncCollectiveTensor):
grad.wait()
Collaborator:
@wanchaol Why do we discard the return value of grad.wait(), which unwraps the ACT? I.e., why not grad = grad.wait()?

Collaborator (Author):
Good catch! I think we should unwrap, yes; it's a bug. Could you help submit a fix for this?

Collaborator:
Oh wait, this looks like it was fixed on main already.

return grad
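
As the thread above points out, the hook in this diff calls grad.wait() without rebinding its return value, so the pending AsyncCollectiveTensor is waited on but never unwrapped. A minimal sketch of the unwrapped variant the reviewers describe (not the exact code later merged on main):

from torch.distributed._functional_collectives import AsyncCollectiveTensor

def sync_grad_hook(grad):
    # wait() blocks until the pending collective completes and returns the
    # materialized plain tensor, so rebind the result before returning it.
    if isinstance(grad, AsyncCollectiveTensor):
        grad = grad.wait()
    return grad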


def _flatten_tensor(
tensor: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
if isinstance(tensor, DistributedTensor):
if isinstance(tensor, DTensor):
tensor._local_tensor.requires_grad_()
return tensor._local_tensor, tensor._spec
return tensor, None


def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
result = DistributedTensor.from_local(
# unflatten is mainly called every time FSDP all-gathers parameters.
result = DTensor.from_local(
tensor,
spec.mesh,
spec.placements,
run_check=False,
)
if tensor.requires_grad:
tensor.register_hook(sync_grad_hook)
return result
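
For orientation, here is a minimal round-trip sketch of how the flatten/unflatten pair above is used (not part of this PR; it assumes torch.distributed is already initialized and calls these private helpers as defined in this diff): FSDP keeps only the local shard plus its DTensorSpec, and rebuilds the DTensor when parameters are all-gathered, at which point the registered hook waits on any pending AsyncCollectiveTensor gradient.

import torch
from torch.distributed._tensor import DTensor, init_device_mesh, Shard
from torch.distributed.tensor.parallel._data_parallel_utils import (
    _flatten_tensor,
    _unflatten_tensor,
)

# Build a 1-D mesh over all ranks (assumes a process group is initialized).
mesh = init_device_mesh("cuda", (torch.distributed.get_world_size(),))

# A DTensor parameter sharded along dim 0: each rank holds one shard.
param = DTensor.from_local(torch.randn(8, 8, requires_grad=True), mesh, [Shard(0)])

# FSDP-side storage: just the local shard plus the spec needed to rebuild it.
local_shard, spec = _flatten_tensor(param)

# After FSDP all-gathers parameters, the shard is re-wrapped as a DTensor and
# sync_grad_hook is registered so pending async gradients are waited on.
rebuilt = _unflatten_tensor(local_shard, spec)
assert isinstance(rebuilt, DTensor)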