You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Update base for Update on "[FSDP][dtensor] use _StridedShard to represent nested sharding for correct full_tensor() result"
Fixes issue #129229#129206
**Summary**
1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding
2. Added a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and simply TP sharding (i.e. non-strided) has the same `full_tensor()` result
3. Re-enabled the tests that were disabled in #129519
**test**
`pytest test/distributed/_composable/fsdp/`
`pytest test/distributed/_composable/test_composability/test_2d_composability.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang LucasLLC MeetVadakkanchery mhorowitz
Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114)
[ghstack-poisoned]
0 commit comments