[DTensor] enable SimpleFSDP's composability with Tensor Parallel #152286
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152286. Note: links to docs will display an error until the docs builds have completed. ❗ 1 currently active SEV. ❌ 1 new failure as of commit f1d228b with merge base 70d7638. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
While this PR is meant to enable SimpleFSDP + TP, it's a bit much to hard-code these variables in the generic DTensor APIs. Users may do Shard(0) and Shard(0) without coming from SimpleFSDP + TP.
Also, we assert that `mesh_dim_names` has to exist, which I believe we have not strictly enforced in the generic DeviceMesh and DTensor use cases. Or are we saying that if `distribute_tensor()` is used with a DTensor, the user has to follow this restriction?
This re-distribute behavior from DTensor submesh --> device submesh is only used by SimpleFSDP right now (@tianyu-l could confirm). I agree it's a bit hardcoded to be SimpleFSDP-specific, whereas FSDP2 handles this in its FSDP param code. I wonder if it would make more sense to have another arg input to specify whether this …
IMHO, DTensor as parallelism infra shouldn't hard-code parallelism information, but should be generalized to support different parallelism styles.
I don't think we should hard code any special things about FSDP or TP in this API
torch/distributed/tensor/_api.py (outdated):

```python
assert outer_mesh.mesh_dim_names is not None
assert inner_mesh.mesh_dim_names is not None

# if the global mesh follows (device_mesh, dtensor_mesh), it means Simple FSDP + TP is enabled
```
what does this mean? "(device_mesh, dtensor_mesh)"
torch/distributed/tensor/_api.py (outdated):

```python
    outer_global_mesh.mesh_dim_names
    == outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
):
    dp_shard_tp = True
```
IMO we should not encode any "dp"/"tp" terms in this API; at the very least you should rename it to "is_strided_shard".
torch/distributed/tensor/_api.py (outdated):

```python
current_spec = DTensorSpec(
    mesh=outer_mesh,
    placements=(Replicate(),),
    tensor_meta=inner_spec.tensor_meta,
)
target_spec = DTensorSpec(
    mesh=outer_mesh,
    placements=(placements[0],),
    tensor_meta=inner_spec.tensor_meta,
)

result_tensor = redistribute_local_tensor(
    tensor._local_tensor,
    current_spec=current_spec,
    target_spec=target_spec,
)
```
can you explain why you need to redistribute here?
I think I wrote this originally at https://github.com/tianyu-l/pytorch_intern24/pull/25/files#diff-6244a472a481d8d8aa0ba7b075cd1e67290b325277fab41df5145209235a0abcR713
This is a convenient way to shard the tensor (from Replicate to Shard), without incurring comms.
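For illustration, here is a minimal standalone sketch of why the Replicate -> Shard direction needs no communication (plain tensors and a hypothetical helper name, not the PR's code): every rank already holds the full tensor, so sharding reduces to a local slice, whereas the reverse direction would require an all-gather.

```python
import torch


def replicate_to_shard_dim0(full_tensor: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Replicate -> Shard(0): each rank keeps its own chunk of the tensor it
    already holds, so no collective is issued (loosely mirroring DTensor's
    torch.chunk-based sharding convention)."""
    chunks = torch.chunk(full_tensor, world_size, dim=0)
    if rank < len(chunks):
        return chunks[rank]
    # Uneven sharding: trailing ranks may end up with an empty shard.
    return full_tensor.new_empty(0, *full_tensor.shape[1:])


# e.g. what rank 1 of a 4-rank mesh dim would keep:
local_shard = replicate_to_shard_dim0(torch.arange(12).reshape(6, 2), rank=1, world_size=4)
```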
I agree we should not hardcode parallelism-related code into DTensor.
To provide some context:
This is to make SimpleFSDP work with TP, at this line https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/simple_fsdp.py#L177
I would hope it can also replace the glue code in FSDP2, per this comment https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L281
It sounds like a non-trivial effort:
Even for the 2D case (where we call `distribute_tensor` on a 1D DTensor), there are three cases:
- there is no sharding interference, e.g. `(Replicate, Shard(0))`, `(Shard(0), Shard(1))`, etc. In FSDP + TP, this happens when TP sharding doesn't happen on the FSDP shard dim (default 0).
- two `Shard()` on the same dimension, first on the outer mesh, then on the inner mesh, e.g. `Shard(0)` -> `(Shard(0), Shard(0))`.
- two `Shard()` on the same dimension, first on the inner mesh, then on the outer mesh, e.g. `Shard(0)` -> `(_StridedShard(0), Shard(0))`. In FSDP + TP, this happens when TP sharding happens on the FSDP shard dim (default 0).
In principle this should be doable, with some evolution of the device mesh APIs.
Before that, we may put the glue code in SimpleFSDP, and replace this `distribute_tensor` call with a wrapper which dispatches to the glue code if the input tensor is a DTensor, just like in FSDP2.
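A rough sketch of that wrapper idea (the names `_distribute_dtensor` and `distribute_like_fsdp2` are hypothetical, not existing PyTorch or torchtitan APIs, and the glue body is deliberately left as a placeholder): only the dispatch on the input type is shown.

```python
from typing import Sequence

import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, Placement, distribute_tensor


def _distribute_dtensor(
    param: DTensor, dp_mesh: DeviceMesh, placements: Sequence[Placement]
) -> DTensor:
    # SimpleFSDP-owned glue code would live here: shard the DTensor's local
    # tensor on the FSDP mesh dim (comm-free Replicate -> Shard) and rebuild a
    # DTensor on the combined (dp, tp) mesh, using _StridedShard when both
    # meshes shard the same tensor dim. Left as a placeholder in this sketch.
    raise NotImplementedError


def distribute_like_fsdp2(
    param: torch.Tensor, dp_mesh: DeviceMesh, placements: Sequence[Placement]
):
    # Dispatch: already-sharded DTensor inputs (e.g. TP params) go to the glue
    # code; plain tensors go through the regular distribute_tensor() path.
    if isinstance(param, DTensor):
        return _distribute_dtensor(param, dp_mesh, placements)
    return distribute_tensor(param, dp_mesh, placements)
```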
Per the discussion with Tianyu, it might be easier to put the TP DTensor distribution as a private function in torchtitan (pytorch/torchtitan#1148), instead of significantly changing the DTensor code. I've updated accordingly. But we would still need to change the dynamo code to ensure …
@bdhirsh does this change look OK to you? If so could you stamp? Thanks!
This PR adds tensor parallel (TP) support for SimpleFSDP, together with [this PR](pytorch/pytorch#152286) from PyTorch. **(This is a placeholder for now, and shall be merged only after the PyTorch PR has landed.)**

**Profile Trace & Convergence on debug model:**

<img width="1578" alt="Screenshot 2025-04-27 at 3 48 35 PM" src="https://github.com/user-attachments/assets/42d36eb7-51c1-4f60-aa2e-20b58d8c491b" />
<img width="1369" alt="Screenshot 2025-04-27 at 3 31 46 PM" src="https://github.com/user-attachments/assets/79e25f88-9e8a-4fe2-b7c9-a7e9747d19a4" />

The loss curves are all benchmarked with `training.seed = 42` on 4 GPUs. The end-to-end SimpleFSDP + TP integration has been proven to work in the PR from this fork: tianyu-l/pytorch_intern24#25. Per the discussion with Tianyu, the new PyTorch PR adds `_StridedShard`, following FSDP2, to be compatible with distributed checkpointing. The newly benchmarked results demonstrate that it works properly.
This PR adds support for SimpleFSDP's composability with Tensor Parallel. This is done by enabling a DTensor redistribution from the FSDP submesh toward the TP submesh in the `distribute_tensor` API.

**Correctness:** The end-to-end SimpleFSDP + TP integration has been proven to work in the PR from this fork: "support Simple FSDP + TP by enabling distribute_tensor on DTensor inputs" (tianyu-l/pytorch_intern24#25). Per the discussion with Tianyu, this PR also adds `_StridedShard`, following FSDP2, to be compatible with distributed checkpointing. The newly benchmarked results demonstrate that it works properly in this torchtitan PR: "[SimpleFSDP] Add tensor parallel support" (torchtitan#1148).

**Example Usage:** There is an example in torchtitan's SimpleFSDP implementation: "[SimpleFSDP] Add tensor parallel support" (torchtitan#1148).
In the example below, given an input DTensor `tensor` sharded in `fully_shard` mode (FSDP) with placement `(Shard(0),)`: if the `device_mesh` is a 2D mesh with FSDP & TP dims, this `tensor` is re-distributed from the FSDP placement to the TP placement.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames
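As a hedged sketch of the call pattern, assuming the torchtitan-style ordering where TP shards a parameter on the inner submesh first and SimpleFSDP then calls `distribute_tensor` on the resulting DTensor with the FSDP placement (mesh names, sizes, and the sample weight are illustrative, not taken from this PR; run under `torchrun` with 4 GPUs):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# 2D mesh: outer "dp" dim for SimpleFSDP, inner "tp" dim for Tensor Parallel.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# TP shards the parameter on the inner "tp" submesh first.
weight = torch.randn(8, 8)
tp_param = distribute_tensor(weight, mesh_2d["tp"], [Shard(0)])

# SimpleFSDP then calls distribute_tensor on the already-DTensor parameter
# with the FSDP placement. Without this change the call rejects a DTensor on
# a different mesh; with it, the parameter is redistributed onto the combined
# 2D mesh (using _StridedShard on the "dp" dim when both dims shard dim 0).
fsdp_tp_param = distribute_tensor(tp_param, mesh_2d["dp"], [Shard(0)])
print(fsdp_tp_param.placements)
```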