diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py
index c14b4be2acf10f..bafe34be772c14 100644
--- a/test/distributed/fsdp/test_fsdp_tp_integration.py
+++ b/test/distributed/fsdp/test_fsdp_tp_integration.py
@@ -8,7 +8,14 @@
 from torch import distributed as dist
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.distributed._tensor import DeviceMesh, DTensor as DT, init_device_mesh, Shard
+from torch.distributed._tensor import (
+    DeviceMesh,
+    distribute_module,
+    DTensor,
+    init_device_mesh,
+    Shard,
+)
+from torch.distributed._tensor.debug import CommDebugMode
 from torch.distributed.fsdp.fully_sharded_data_parallel import (
     CPUOffload,
     FullyShardedDataParallel as FSDP,
@@ -26,6 +33,7 @@
     run_tests,
     TEST_WITH_DEV_DBG_ASAN,
 )
+from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
 
 if not dist.is_available():
     print("Distributed not available, skipping tests", file=sys.stderr)
@@ -68,6 +76,34 @@ def get_non_sharded_param_names() -> List[str]:
         return ["net3.weight", "net3.bias"]
 
 
+# simple RMSNorm layer for testing
+class RMSNormPython(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = torch.nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x)
+        return output * self.weight
+
+
+def distribute_rmsnorm(module, device_mesh):
+    def prepare_input_fn(inputs, device_mesh):
+        shard_tensor = DTensor.from_local(inputs[0], device_mesh, [Shard(0)])
+        return shard_tensor
+
+    def prepare_output_fn(outputs, device_mesh):
+        return outputs.to_local()
+
+    return distribute_module(
+        module, device_mesh, input_fn=prepare_input_fn, output_fn=prepare_output_fn
+    )
+
+
 class TestTPFSDPIntegration(FSDPTest):
     def _get_params_and_sharding_info(
         self,
@@ -260,8 +296,8 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
             sequence_parallelize_plan,
         )
         tp_pg = mesh_2d["tp"].get_group(mesh_dim=0)
-        assert isinstance(tp_fsdp_model.net1.weight, DT)
-        assert isinstance(tp_fsdp_model.net2.weight, DT)
+        assert isinstance(tp_fsdp_model.net1.weight, DTensor)
+        assert isinstance(tp_fsdp_model.net2.weight, DTensor)
         tp_fsdp_model = FSDP(
             tp_fsdp_model,
             cpu_offload=cpu_offload,
@@ -314,6 +350,51 @@ def test_fsdp_tp_integration(self, tensor_parallel_size, cpu_offload):
             tp_fsdp_out = tp_fsdp_model(inp)
             self.assertEqual(fsdp_out, tp_fsdp_out)
 
+    @skip_if_lt_x_gpu(4)
+    def test_fsdp_tp_gradient_layout(self):
+        """
+        Tests TP + FSDP extension with consistent gradient layout
+        """
+        mesh_2d = init_device_mesh(
+            "cuda", (self.world_size // 2, 2), mesh_dim_names=["dp", "tp"]
+        )
+
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.mlp = MLPModule("cuda")
+                self.mlp_norm = RMSNormPython(10)
+
+            def forward(self, x):
+                return self.mlp(self.mlp_norm(x))
+
+        model = TestModel().cuda(self.rank)
+
+        # Shard with TP and test gradient
+        tp_mesh = mesh_2d["tp"]
+        tp_model = parallelize_module(
+            model,
+            tp_mesh,
+            {
+                "mlp.net1": ColwiseParallel(input_layouts=Shard(0)),
+                "mlp.net2": RowwiseParallel(output_layouts=Shard(0)),
+            },
+        )
+        distribute_rmsnorm(tp_model.mlp_norm, tp_mesh)
+
+        fsdp_2d_model = FSDP(tp_model, device_mesh=mesh_2d["dp"])
+        comm_mode = CommDebugMode()
+
+        with comm_mode:
+            fsdp_2d_model(torch.rand(2, 10).cuda(self.rank)).sum().backward()
+
+        funcol = torch.ops.c10d_functional
+        comm_counts = comm_mode.get_comm_counts()
+        self.assertEqual(comm_mode.get_total_counts(), 5)
+        self.assertEqual(comm_counts[funcol.reduce_scatter_tensor], 2)
+        self.assertEqual(comm_counts[funcol.all_gather_into_tensor], 2)
+        self.assertEqual(comm_counts[funcol.all_reduce], 1)
+
 
 instantiate_parametrized_tests(TestTPFSDPIntegration)
 
diff --git a/torch/distributed/tensor/parallel/_data_parallel_utils.py b/torch/distributed/tensor/parallel/_data_parallel_utils.py
index d19258d93ee353..1edc00374ae3bc 100644
--- a/torch/distributed/tensor/parallel/_data_parallel_utils.py
+++ b/torch/distributed/tensor/parallel/_data_parallel_utils.py
@@ -1,24 +1,34 @@
 from typing import Optional, Tuple
 
 import torch
-from torch.distributed._tensor import DTensor as DistributedTensor
+from torch.distributed._functional_collectives import AsyncCollectiveTensor
+from torch.distributed._tensor import DTensor
 from torch.distributed._tensor.placement_types import DTensorSpec
 
 
+def sync_grad_hook(grad):
+    if isinstance(grad, AsyncCollectiveTensor):
+        grad.wait()
+    return grad
+
+
 def _flatten_tensor(
     tensor: torch.Tensor,
 ) -> Tuple[torch.Tensor, Optional[DTensorSpec]]:
-    if isinstance(tensor, DistributedTensor):
+    if isinstance(tensor, DTensor):
         tensor._local_tensor.requires_grad_()
         return tensor._local_tensor, tensor._spec
     return tensor, None
 
 
 def _unflatten_tensor(tensor: torch.Tensor, spec: DTensorSpec) -> torch.Tensor:
-    result = DistributedTensor.from_local(
+    # unflatten would mainly be called every time FSDP all-gathers parameters.
+    result = DTensor.from_local(
         tensor,
         spec.mesh,
         spec.placements,
         run_check=False,
     )
+    if tensor.requires_grad:
+        tensor.register_hook(sync_grad_hook)
     return result
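
Reviewer note (not part of the diff): the `_unflatten_tensor` change relies on `torch.Tensor.register_hook`, which invokes a callable on a tensor's gradient during backward and lets it hand back a (possibly processed) gradient. That is the entry point that gives `sync_grad_hook` a chance to wait on an `AsyncCollectiveTensor` gradient before FSDP reduces it. Below is a minimal single-process sketch of just that hook mechanism; the helper name `print_and_pass_through` is illustrative only, and no DTensor or collectives are involved.

import torch

def print_and_pass_through(grad):
    # Hooks registered via Tensor.register_hook run when this tensor's gradient
    # is produced during backward; returning a tensor gives it back to autograd.
    # sync_grad_hook uses the same entry point to wait on async collective
    # gradients before they are consumed.
    print("grad received:", grad)
    return grad

x = torch.ones(3, requires_grad=True)
x.register_hook(print_and_pass_through)
(x * 2).sum().backward()
print(x.grad)  # tensor([2., 2., 2.])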