
[FSDP2+TP] Disable 2D state_dict #129519


Closed
wz337 wants to merge 1 commit from the disallow_state_dict_call_on_2d branch

Conversation

Contributor
@wz337 wz337 commented Jun 25, 2024

Fixes #ISSUE_NUMBER

Gonna fill in the RFC but just want to run CI to see if anything else breaks.

Test:

python test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_raise_not_implemented_state_dict_if_2d
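For context, a minimal sketch of the scenario this test exercises (the model, mesh shape, and TP plan below are illustrative assumptions, not the test's actual setup): with this PR, calling `state_dict()` on a 2D (FSDP+TP) model raises `NotImplementedError`.

```python
# Illustrative sketch only -- run under torchrun with 4 GPUs; the model,
# mesh shape, and TP plan are assumptions, not the test's actual setup.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# 2D mesh: data-parallel (FSDP) x tensor-parallel (TP)
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
parallelize_module(model, mesh_2d["tp"], {"0": ColwiseParallel(), "1": RowwiseParallel()})
fully_shard(model, mesh=mesh_2d["dp"])

# With this PR, 2D state_dict is disabled until strided sharding lands:
try:
    model.state_dict()
except NotImplementedError as e:
    print(e)
```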

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @fegin @XilunWu @wanchaol @fduwjj @tianyu-l @wconstab @yf225 @chauhang @d4l3k

pytorch-bot bot commented Jun 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/129519

Note: Links to docs will display an error until the docs builds have been completed.


❌ 2 New Failures, 1 Unrelated Failure

As of commit 790e786 with merge base 1c75ddf:

NEW FAILURES — 2 jobs failed.

FLAKY — 1 job failed, but likely due to flakiness present on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the oncall: distributed and release notes: distributed (fsdp) labels Jun 25, 2024
@wz337 wz337 force-pushed the disallow_state_dict_call_on_2d branch from b40cc9f to 6c7d7b8 Compare June 25, 2024 22:08
@wz337 wz337 changed the title [FSDP2+TP] Disable 2D state_dict [WIP][FSDP2+TP] Disable 2D state_dict Jun 25, 2024
@wz337 wz337 requested review from fegin, awgu, XilunWu and lessw2020 June 25, 2024 22:59
@wz337 wz337 force-pushed the disallow_state_dict_call_on_2d branch 6 times, most recently from 77bc05f to 5a5b5e7 Compare June 27, 2024 03:38
@@ -223,6 +223,22 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device):
tensor_meta=self._tp_spec.tensor_meta,
)
param_data = cast(DTensor, param)._local_tensor

Contributor Author


@awgu Want to check with you whether disabling the 2D state_dict like this is ok with you.

I was originally thinking about doing it in _register_state_dict_hooks, but since _register_state_dict_hooks happens during lazy_init, I ended up registering the hooks in param init.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/_composable/fsdp/_fsdp_param_group.py#L469

Collaborator


I think we can move this to the end of FSDPParamGroup.__init__ so that we do not need the if len(cur_module._state_dict_pre_hooks) == 0: check in the case that a module has >1 parameter.

Contributor Author
@wz337 wz337 Jun 27, 2024


I was looking to add this in FSDPParamGroup initially, but since we only wanted to add this hook for 2D, we would need an if in FSDPParamGroup.__init__ as well. Something like this:

if any(len(fsdp_param._spmd_placements) > 1 for fsdp_param in self.fsdp_params):
    self.module.register_state_dict_pre_hook(...)
    self.module._register_load_state_dict_pre_hook(...)

In this case, we would also run the extra if check for 1D. With the current approach, we only register the hook if we encounter 2D placements, so nothing changes on the 1D side, but we do run the _state_dict_pre_hooks check on every param init for 2D.

Not sure which one is more preferable to you.

Contributor Author


Discussed with @awgu offline; moved the check to the end of FSDPParamGroup.__init__.
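For reference, a minimal sketch of the agreed placement (adapted from the snippet above; the hook function name and error message are assumptions, not the exact PR diff):

```python
def _raise_not_implemented(*args, **kwargs):
    raise NotImplementedError("2D state_dict is not supported yet")

# At the end of FSDPParamGroup.__init__, once self.fsdp_params is built,
# register raising hooks only when a parameter has 2D (FSDP+TP) placements:
if any(len(fsdp_param._spmd_placements) > 1 for fsdp_param in self.fsdp_params):
    self.module.register_state_dict_pre_hook(_raise_not_implemented)
    self.module._register_load_state_dict_pre_hook(_raise_not_implemented)
```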

@wz337 wz337 marked this pull request as ready for review June 27, 2024 15:53
@wz337 wz337 force-pushed the disallow_state_dict_call_on_2d branch 2 times, most recently from ee146c3 to 50a59e5 Compare June 28, 2024 16:56
@wz337 wz337 force-pushed the disallow_state_dict_call_on_2d branch from 50a59e5 to 790e786 Compare June 28, 2024 16:57
Collaborator
@awgu awgu left a comment


Looks great! Thank you!

@awgu awgu added the release notes: distributed (fsdp2) label and removed the release notes: distributed (fsdp) label Jun 29, 2024
Contributor Author
wz337 commented Jul 1, 2024

@pytorchmergebot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Jul 1, 2024
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following check: pull / linux-focal-py3.8-clang10 / test (dynamo, 3, 3, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-py3-arm64-mps / test (mps, 1, 1, macos-m1-14)

Details for Dev Infra team: raised by workflow job

XilunWu added a commit that referenced this pull request Jul 22, 2024
…ding for correct full_tensor() result"


Fixes issues #129229 and #129206
**Summary**

1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding
2. Add a new property `num_shards_map` to `DTensorSpec` denoting how many shards each tensor dimension has. This is necessary for constructing the `_StridedShard` placement when we call `distribute_tensor(dtensor_tp, dp_device_mesh, [Shard(0)])`; the `split_factor` argument is just the number of shards on that sharded tensor dim.
3. Add a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) have the same `full_tensor()` result
4. Re-enable the tests that were disabled in #129519 and remove the relevant code

**test**
`pytest test/distributed/_composable/fsdp/test_fully_shard_training.py`
`pytest test/distributed/_composable/fsdp/test_fully_shard_state_dict.py`

cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o zhaojuanmao mrshenli rohan-varma chauhang 

[ghstack-poisoned]
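For illustration, a hedged sketch of the nested-sharding call pattern the summary describes (mesh shape and tensor size are assumptions; `distribute_tensor` on a DTensor mirrors the internal FSDP call named above):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, distribute_tensor

mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

w = torch.arange(256, dtype=torch.float32).reshape(16, 16)  # same on all ranks
# TP shards dim 0 on the "tp" mesh dimension first...
w_tp = distribute_tensor(w, mesh_2d["tp"], [Shard(0)])
# ...then FSDP shards the already-sharded dim 0 again on "dp". With this
# change, the second sharding uses a _StridedShard placement whose
# split_factor equals the number of shards already on dim 0, so that
# full_tensor() reassembles the rows in the correct order.
w_2d = distribute_tensor(w_tp, mesh_2d["dp"], [Shard(0)])
assert torch.equal(w_2d.full_tensor(), w)
```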
pytorchmergebot pushed a commit that referenced this pull request Aug 8, 2024
…rrect full_tensor() result (#130760)

Fixes issues #129229 and #129206
**Summary**

1. Have `FSDP` choose `_StridedShard` placement for FSDP+TP sharding
2. Add a parity test to FSDP to ensure that FSDP+TP sharding (i.e. strided) and plain TP sharding (i.e. non-strided) have the same `full_tensor()` result
3. Re-enable the tests that were disabled in #129519

**test**
`pytest test/distributed/_composable/fsdp/`
`pytest test/distributed/_composable/test_composability/test_2d_composability.py`
`pytest test/distributed/checkpoint/fsdp/test_fsdp_dsd.py`

Differential Revision: [D60606114](https://our.internmc.facebook.com/intern/diff/D60606114)
Pull Request resolved: #130760
Approved by: https://github.com/wanchaol, https://github.com/fegin, https://github.com/wz337
ghstack dependencies: #126697, #130239, #132391, #131408
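For context, a hedged sketch of the parity property item 2 asserts (`model_2d` and `model_tp` are assumed to be a 2D FSDP+TP model and a TP-only copy of the same weights, not the test's actual fixtures):

```python
import torch

# Every parameter of the strided (FSDP+TP) model should materialize to the
# same full tensor as its non-strided (TP-only) counterpart.
for (name_2d, p_2d), (name_tp, p_tp) in zip(
    model_2d.named_parameters(), model_tp.named_parameters()
):
    assert name_2d == name_tp
    torch.testing.assert_close(p_2d.full_tensor(), p_tp.full_tensor())
```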
Labels

ciflow/trunk · Merged · oncall: distributed · release notes: distributed (fsdp2)