[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result
Fixes issues #129229 and #129206
**Summary**
1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding
2. Add a new property `num_shards_map` to `DTensorSpec` that records how many shards each tensor dimension has. This is needed to construct the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`: the `split_factor` argument is simply the number of shards already present on that tensor dimension.
3. Add a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and TP-only sharding (i.e. non-strided) produce the same `full_tensor()` result (see the sketch after this list)
4. Re-enabled the tests that were disabled in #129519 and removed relevant code
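
The sketch below is a minimal, CPU-only illustration of why the strided placement matters. It mimics the FSDP+TP order of operations with plain `torch.chunk`: TP shards dim 0 first, then FSDP shards each TP-local chunk again. Reassembling the local shards in dp-major order (conceptually what a non-strided `[Shard(0), Shard(0)]` reconstruction implies) scrambles the elements, while the interleaved order that `_StridedShard(0, split_factor=tp_size)` encodes recovers the original tensor. The mesh sizes and variable names here are illustrative assumptions, not code from this PR.

```python
import torch

# Illustrative 2x2 layout: 2 TP ranks, 2 FSDP (dp) ranks.
tp_size, dp_size = 2, 2
full = torch.arange(8)

# Step 1: TP shards dim 0 first -- each TP rank holds a contiguous chunk.
tp_chunks = full.chunk(tp_size)                    # [0,1,2,3], [4,5,6,7]

# Step 2: FSDP then shards each TP-local chunk across the dp ranks.
local = {(d, t): tp_chunks[t].chunk(dp_size)[d]
         for d in range(dp_size) for t in range(tp_size)}
# rank (dp=0, tp=0) -> [0,1]   (dp=1, tp=0) -> [2,3]
# rank (dp=0, tp=1) -> [4,5]   (dp=1, tp=1) -> [6,7]

# Non-strided reconstruction (dp-major) concatenates in the wrong order:
naive = torch.cat([local[(d, t)] for d in range(dp_size) for t in range(tp_size)])
print(naive)    # tensor([0, 1, 4, 5, 2, 3, 6, 7])  -- scrambled

# Strided reconstruction -- conceptually what _StridedShard(0, split_factor=tp_size)
# lets full_tensor() do -- rebuilds each TP chunk from its dp sub-shards first:
strided = torch.cat([local[(d, t)] for t in range(tp_size) for d in range(dp_size)])
print(strided)  # tensor([0, 1, 2, 3, 4, 5, 6, 7])  -- matches the original
assert torch.equal(strided, full)
```

The parity test in item 3 checks the same property end-to-end: the `full_tensor()` of the FSDP+TP (strided) DTensor must equal that of the TP-only (non-strided) one.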
**Tests**
`pytest test/distributed/_composable/fsdp/test_fully_shard_training.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_init.py`
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang
[ghstack-poisoned]