pytorch/pytorch · Commit c2ce004

Update on "[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result"
Fixes issues #129229, #129206

**Summary**
1. Have `FSDP` choose the `_StridedShard` placement for FSDP+TP sharding.
2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) produce the same `full_tensor()` result.
3. Re-enabled the tests that were disabled in #129519.

**Test**
`pytest test/distributed/_composable/fsdp/`
`pytest test/distributed/_composable/test_composability/test_2d_composability.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang LucasLLC MeetVadakkanchery mhorowitz

Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114)

[ghstack-poisoned]
2 parents 40bd6ef + 566eb66 commit c2ce004
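
The layout problem that `_StridedShard` addresses can be sketched in a single process, without a process group. The mesh sizes (dp=2, tp=2) and the size-8 tensor below are illustrative assumptions, not values from this commit: when TP shards dim 0 first and FSDP then shards each TP-local piece on dim 0 again, the row chunks are interleaved relative to what a plain [Shard(0), Shard(0)] spec describes, so a naive full_tensor()-style reconstruction comes back in the wrong order. This is a sketch of the mechanism, not the DTensor implementation.

# Single-process sketch (assumed dp_size=2, tp_size=2, dim-0 size 8).
import torch

dp_size, tp_size = 2, 2
full = torch.arange(8)

# TP shards dim 0 first; FSDP then shards each TP-local piece on dim 0 again.
tp_chunks = full.chunk(tp_size, dim=0)  # [0..3], [4..7]
local = {
    (dp, tp): tp_chunks[tp].chunk(dp_size, dim=0)[dp]
    for dp in range(dp_size)
    for tp in range(tp_size)
}

# Reconstruction implied by a plain [Shard(0) on dp, Shard(0) on tp] spec:
# dp is treated as the outer split and tp as the inner one.
naive = torch.cat(
    [torch.cat([local[dp, tp] for tp in range(tp_size)]) for dp in range(dp_size)]
)
# -> tensor([0, 1, 4, 5, 2, 3, 6, 7]): rows come back interleaved.

# Reconstruction that respects the actual nesting (tp outer, dp inner), which
# is what a strided placement on the dp mesh dim lets full_tensor() recover.
strided = torch.cat(
    [torch.cat([local[dp, tp] for dp in range(dp_size)]) for tp in range(tp_size)]
)
assert not torch.equal(naive, full)
assert torch.equal(strided, full)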

File tree: 5 files changed, +21 -29 lines
test/distributed/_composable/fsdp/test_fully_shard_init.py (2 additions, 4 deletions)

@@ -387,7 +387,7 @@ def test_shard_dtensor_parameters(self):
         )
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]
         # Use odd dim sizes to test uneven shards
-        # TODO change "mlp_dim" back to 8 when uneven sharding
+        # TODO: change "mlp_dim" back to 9 when uneven sharding
         # is supported for FSDP+TP
         model = MLP(8, dim_multiplier=3)
         orig_params = [param.detach().clone() for param in model.parameters()]
@@ -583,9 +583,7 @@ def test_meta_device_2d_init(self):
         dp_mesh, tp_mesh = global_mesh["dp"], global_mesh["tp"]

         # Test both even sharding (8) and uneven sharding (3)
-        # TODO change "mlp_dim" back to (8, 3) when uneven sharding
-        # is supported for FSDP+TP
-        for mlp_dim in (8, 4):
+        for mlp_dim in (8, 3):
             with torch.device("meta"):
                 model = MLP(mlp_dim, with_buffer=True)
                 for param in model.parameters():
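
As a reminder of what the re-enabled mlp_dim=3 case exercises: "uneven sharding" means a dim size that does not divide evenly by the mesh size, so ranks end up with different local shard lengths. A tiny single-process illustration with torch.chunk (the sizes here are made up; this only shows the unequal split, not DTensor's actual sharding code):

import torch

t = torch.arange(3)               # dim-0 size 3
shards = list(t.chunk(2, dim=0))  # split across an assumed 2-rank dp mesh
print([s.shape[0] for s in shards])  # [2, 1]: unequal local shard sizes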

test/distributed/_composable/fsdp/test_fully_shard_state_dict.py (4 additions, 13 deletions)

@@ -177,19 +177,10 @@ def _test_dp_tp_state_dict_save_load(self, global_mesh: DeviceMesh, mlp_dim: int
                 "2.out_proj": RowwiseParallel(),
             },
         )
-        # TODO: remove ``assertRaisesRegex`` once uneven sharding is supported
-        if mlp_dim % dp_mesh.size() != 0:
-            with self.assertRaisesRegex(
-                NotImplementedError, "does not support uneven sharding"
-            ):
-                for mlp in model:
-                    fully_shard(mlp, mesh=dp_mesh)
-                fully_shard(model, mesh=dp_mesh)
-        else:
-            for mlp in model:
-                fully_shard(mlp, mesh=dp_mesh)
-            fully_shard(model, mesh=dp_mesh)
-            self._test_state_dict_save_load(model)
+        for mlp in model:
+            fully_shard(mlp, mesh=dp_mesh)
+        fully_shard(model, mesh=dp_mesh)
+        self._test_state_dict_save_load(model)

     def _test_state_dict_save_load(self, model: nn.Module):
         for param_name, param in model.named_parameters():

test/distributed/_composable/fsdp/test_fully_shard_training.py (1 addition, 1 deletion)

@@ -993,7 +993,7 @@ def test_2d_mlp_with_nd_mesh(self):
             {
                 "reshard_after_forward": [False, True],
                 "use_activation_checkpointing": [False, True],
-                # TODO change "mlp_dim" back to [3, 16, 17] when uneven sharding
+                # TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
                 # is supported for FSDP+TP
                 "mlp_dim": [4, 16, 20],
                 "foreach": [False],

test/distributed/_composable/test_composability/test_2d_composability.py (1 addition, 1 deletion)

@@ -60,7 +60,7 @@ def test_train_parity_2d_mlp(self):
             {
                 "reshard_after_forward": [False, True],
                 "use_activation_checkpointing": [False, True],
-                # TODO change "mlp_dim" back to [3, 16, 17] when uneven sharding
+                # TODO: change "mlp_dim" back to [3, 16, 17] when uneven sharding
                 # is supported for FSDP+TP
                 "mlp_dim": [4, 16, 20],
             },

torch/distributed/_composable/fsdp/_fsdp_param.py (13 additions, 10 deletions)

@@ -302,16 +302,19 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
             )
             # NOTE: FSDP+TP does not support uneven sharding for now
             # TODO: enable uneven sharding for FSDP+TP
-            num_shards_map = self._sharding_spec.num_shards_map
-            tensor_shape = list(self._sharding_spec.shape)
-            assert len(num_shards_map) == len(tensor_shape)
-            for i, (size, num_shards) in enumerate(zip(tensor_shape, num_shards_map)):
-                if size % num_shards != 0:
-                    raise NotImplementedError(
-                        "FSDP+TP sharding does not support uneven sharding for now: "
-                        f"tensor dim {i} has size {size} which cannot be evenly sharded "
-                        f"into {num_shards} shards."
-                    )
+            if split_factor > 1:  # FSDP has strided sharding on tensor dim 0
+                num_shards_map = self._sharding_spec.num_shards_map
+                tensor_shape = list(self._sharding_spec.shape)
+                assert len(num_shards_map) == len(tensor_shape)
+                for i, (size, num_shards) in enumerate(
+                    zip(tensor_shape, num_shards_map)
+                ):
+                    if size % num_shards != 0:
+                        raise NotImplementedError(
+                            "FSDP+TP sharding does not support uneven sharding for now: "
+                            f"tensor dim {i} has size {size} which cannot be evenly "
+                            f"sharded into {num_shards} shards."
+                        )

             param_data = cast(DTensor, param)._local_tensor
         else:
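
To make the new gating explicit: the divisibility check now runs only when split_factor > 1, i.e. only when FSDP's dim-0 shard is strided on top of a TP shard of the same tensor dim, so plain (non-strided) FSDP sharding keeps supporting uneven shards. The standalone helper below restates that logic for illustration only; its name and the example shapes are assumptions, not part of the FSDP code.

# Hypothetical restatement of the check above (not FSDP API).
from typing import Sequence

def check_even_sharding_for_strided(
    tensor_shape: Sequence[int],
    num_shards_map: Sequence[int],
    split_factor: int,
) -> None:
    """Raise if strided (FSDP+TP) sharding would produce uneven shards."""
    if split_factor <= 1:  # no strided sharding on dim 0 -> no restriction
        return
    assert len(num_shards_map) == len(tensor_shape)
    for i, (size, num_shards) in enumerate(zip(tensor_shape, num_shards_map)):
        if size % num_shards != 0:
            raise NotImplementedError(
                "FSDP+TP sharding does not support uneven sharding for now: "
                f"tensor dim {i} has size {size} which cannot be evenly "
                f"sharded into {num_shards} shards."
            )

# Example (made-up shapes): a (9, 8) weight whose dim 0 is split into 4 shards
# passes when the sharding is plain, but is rejected when it is strided.
check_even_sharding_for_strided((9, 8), (4, 1), split_factor=1)   # ok
# check_even_sharding_for_strided((9, 8), (4, 1), split_factor=2)  # raises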
