
[DTensor] enable SimpleFSDP's composability with Tensor Parallel #152286

Open · wants to merge 1 commit into base: main

Conversation

ruisizhang123 (Contributor) commented Apr 27, 2025

This PR adds support for SimpleFSDP's composability with Tensor Parallel. This is done by enabling DTensor redistribution from the FSDP submesh to the TP submesh in the distribute_tensor API.

  1. Correctness: The end-to-end SimpleFSDP + TP integration has been shown to work in the PR from this fork: support Simple FSDP + TP by enabling distribute_tensor on DTensor inputs tianyu-l/pytorch_intern24#25. Per the discussion with Tianyu, this PR also adds _StridedShard, following FSDP2, to be compatible with distributed checkpointing. The new benchmark results in this torchtitan PR show that it works properly: [SimpleFSDP] Add tensor parallel support torchtitan#1148.

  2. Example Usage: There is an example in torchtitan's SimpleFSDP implementation: [SimpleFSDP] Add tensor parallel support torchtitan#1148.

In the example below, the input DTensor is sharded in fully_shard mode (FSDP) with placement (Shard(0),). If device_mesh is a 2D mesh with FSDP and TP dims, the tensor is redistributed from the FSDP placement to the TP placement.

distribute_tensor(tensor, device_mesh, param_sharding)
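
A minimal sketch of the kind of call this enables (illustrative only: it assumes 4 GPUs launched via torchrun arranged as a 2x2 ("dp", "tp") mesh, and the mesh direction and placements are placeholders rather than the exact torchtitan wiring):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# 2D mesh with an FSDP ("dp") dim and a TP ("tp") dim: 4 GPUs as 2 x 2.
mesh_2d = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

# A parameter that is already a 1D DTensor, sharded on one submesh with (Shard(0),).
param = distribute_tensor(torch.randn(64, 64), mesh_2d["dp"], [Shard(0)])

# With this PR, distribute_tensor also accepts a DTensor input and redistributes
# it onto the other submesh instead of rejecting the mismatched mesh.
param_sharding = [Shard(0)]  # illustrative placement
param_2d = distribute_tensor(param, mesh_2d["tp"], param_sharding)
```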

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @chauhang @amjames

pytorch-bot bot commented Apr 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152286

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit f1d228b with merge base 70d7638:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: dynamo and oncall: distributed labels Apr 27, 2025

This PR needs a release notes: label

If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@soulitzer soulitzer added the triaged label Apr 28, 2025
fegin (Contributor) left a comment

While this PR is to enable Simple FSDP + TP, it's a little too specific to hard-code these variables in the generic DTensor APIs. Users may do Shard(0) and Shard(0) even when they are not coming from Simple FSDP + TP.

Also, we assert that mesh_dim_names has to exist, which I believe we have not strictly enforced in the generic DeviceMesh and DTensor use cases.

Or are we saying that if distribute_tensor() is used with a DTensor, the user has to follow this restriction?

ruisizhang123 (Contributor, Author) commented Apr 28, 2025

generic DTensor APIs. Users may do Shard(0) and Shard(0) even when they are not coming from Simple FSDP + TP.

Also, we assert that mesh_dim_names has to exist, which I believe we have not strictly enforced in the generic DeviceMesh and DTensor use cases.

This redistribute behavior from the DTensor submesh --> device submesh is only used by SimpleFSDP right now (@tianyu-l could confirm). I agree it's a bit hardcoded to be SimpleFSDP-specific, whereas FSDP2 handles this in its FSDP param code.

I wonder if it would make more sense to have another argument that specifies whether this distribute_tensor call comes from SimpleFSDP, so that the redistribution is activated only when SimpleFSDP is enabled. 🤔

fegin (Contributor) commented Apr 30, 2025

IMHO, DTensor as parallelism infra shouldn't hard-code parallelism information; it should stay general enough to support different parallelisms.

wanchaol (Collaborator) left a comment

I don't think we should hard-code anything special about FSDP or TP in this API.

assert outer_mesh.mesh_dim_names is not None
assert inner_mesh.mesh_dim_names is not None

# if the global mesh follows (device_mesh, dtensor_mesh), it means Simple FSDP + TP is enabled
Collaborator

what does this mean? "(device_mesh, dtensor_mesh)"

outer_global_mesh.mesh_dim_names
== outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
):
dp_shard_tp = True
Collaborator

IMO we should not encode any "dp"/"tp" terms in this API; at the very least, you should rename it to "is_strided_shard".

Comment on lines 783 to 798
current_spec = DTensorSpec(
mesh=outer_mesh,
placements=(Replicate(),),
tensor_meta=inner_spec.tensor_meta,
)
target_spec = DTensorSpec(
mesh=outer_mesh,
placements=(placements[0],),
tensor_meta=inner_spec.tensor_meta,
)

result_tensor = redistribute_local_tensor(
tensor._local_tensor,
current_spec=current_spec,
target_spec=target_spec,
)
Collaborator

can you explain why you need to redistribute here?

Contributor

I think I wrote this originally at https://github.com/tianyu-l/pytorch_intern24/pull/25/files#diff-6244a472a481d8d8aa0ba7b075cd1e67290b325277fab41df5145209235a0abcR713

This is a convenient way to shard the tensor (from Replicate to Shard), without incurring comms.
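
For intuition, a minimal standalone sketch (not the PR's code; uneven shards and padding are ignored) of why Replicate -> Shard(0) needs no collective: every rank already holds the full tensor, so it just keeps its own slice:

```python
import torch

def shard_from_replicated(full_tensor: torch.Tensor, num_shards: int, my_rank: int) -> torch.Tensor:
    # Each rank slices out its own dim-0 chunk of the tensor it already holds,
    # so no communication is issued.
    return full_tensor.chunk(num_shards, dim=0)[my_rank]

# e.g. rank 1 of a 4-way mesh keeps rows 2-3 of an 8-row tensor:
# shard_from_replicated(torch.arange(16).reshape(8, 2), num_shards=4, my_rank=1)
```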

tianyu-l (Contributor) left a comment

I agree we should not hardcode parallelism-related code into DTensor.

To provide some context:
This is to make SimpleFSDP work with TP, at this line https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/simple_fsdp.py#L177
I would hope it can also replace the glue code in FSDP2, per this comment https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L281

It sounds like a non-trivial effort:
Even for the 2D case (where we call distribute_tensor on a 1D DTensor), there are three cases (sketched below):

  1. there is no sharding interference, e.g. (Replicate, Shard(0)), (Shard(0), Shard(1)), etc. In FSDP + TP, this happens when TP sharding doesn't happen on the FSDP shard dim (default 0).
  2. two Shard()s on the same dimension, first on the outer mesh, then on the inner mesh, e.g. Shard(0) -> (Shard(0), Shard(0)).
  3. two Shard()s on the same dimension, first on the inner mesh, then on the outer mesh, e.g. Shard(0) -> (_StridedShard(0), Shard(0)). In FSDP + TP, this happens when TP sharding happens on the FSDP shard dim (default 0).
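
A compact way to read the three cases, with placements written as (outer/FSDP, inner/TP). This is only an illustrative sketch; _StridedShard is a private placement type, so it is described in a comment rather than imported:

```python
from torch.distributed.tensor import Replicate, Shard

# Case 1: no sharding interference; TP replicates or shards a different dim.
case_1a = (Replicate(), Shard(0))
case_1b = (Shard(0), Shard(1))

# Case 2: outer (FSDP) shards dim 0 first, then inner (TP) shards dim 0 again.
case_2 = (Shard(0), Shard(0))

# Case 3: inner (TP) shards dim 0 first, then outer (FSDP) shards dim 0 on top;
# the outer placement must become _StridedShard(0) so that reconstruction and
# distributed checkpointing see the correct global layout.
```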

In principle this should be doable, with some evolution of the DeviceMesh APIs.
Before that, we may put the glue code in SimpleFSDP, and replace this distribute_tensor call with a wrapper which dispatches to the glue code if the input tensor is a DTensor, just like in FSDP2.

cc: @wanchaol @fegin @awgu

ruisizhang123 (Contributor, Author) commented May 6, 2025

I agree we should not hardcode parallelism-related code into DTensor.

To provide some context: This is to make SimpleFSDP work with TP, at this line https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/simple_fsdp/simple_fsdp.py#L177 I would hope it can also replace the glue code in FSDP2, per this comment https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_param.py#L281

It sounds like a non-trivial effort: Even for the 2D case (where we call distribute_tensor on a 1D DTensor), there are three cases:

  1. there is no sharding interference, e.g. (Replicate, Shard(0)), (Shard(0), Shard(1)), etc. In FSDP + TP, this happens when TP sharding doesn't happen on the FSDP shard dim (default 0).
  2. two Shard()s on the same dimension, first on the outer mesh, then on the inner mesh, e.g. Shard(0) -> (Shard(0), Shard(0)).
  3. two Shard()s on the same dimension, first on the inner mesh, then on the outer mesh, e.g. Shard(0) -> (_StridedShard(0), Shard(0)). In FSDP + TP, this happens when TP sharding happens on the FSDP shard dim (default 0).

In principle this should be doable, with some evolution of the DeviceMesh APIs. Before that, we may put the glue code in SimpleFSDP, and replace this distribute_tensor call with a wrapper which dispatches to the glue code if the input tensor is a DTensor, just like in FSDP2.

cc: @wanchaol @fegin @awgu

Per the discussion with Tianyu, it might be easier to put the TP DTensor distribution logic in a private function in torchtitan (pytorch/torchtitan#1148) instead of significantly changing the DTensor code. I've updated accordingly.
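
A rough sketch of that wrapper/dispatch approach (hypothetical: `_distribute_dtensor` and `simple_fsdp_distribute` are illustrative names, not existing torchtitan or PyTorch APIs):

```python
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor import DTensor, distribute_tensor


def _distribute_dtensor(tensor: DTensor, mesh: DeviceMesh, placements) -> DTensor:
    # SimpleFSDP-specific glue code (e.g. the _StridedShard handling for the
    # 2D FSDP + TP case) would live here, outside the generic DTensor API.
    raise NotImplementedError


def simple_fsdp_distribute(tensor: torch.Tensor, mesh: DeviceMesh, placements) -> DTensor:
    if isinstance(tensor, DTensor):
        # Parameter was already sharded (e.g. by TP): dispatch to the glue code.
        return _distribute_dtensor(tensor, mesh, placements)
    # Plain tensor: the generic distribute_tensor API is enough.
    return distribute_tensor(tensor, mesh, placements)
```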

But we would still need to change the dynamo code to ensure _StridedShard can pass the dtype check. 🤔 I've updated this PyTorch PR to fix the dtype checking issue.

cc: @wanchaol @fegin @awgu

@tianyu-l tianyu-l requested a review from bdhirsh May 14, 2025 08:27
tianyu-l (Contributor) commented

@bdhirsh does this change look OK to you? If so could you stamp? Thanks!

ruisizhang123 added a commit to pytorch/torchtitan that referenced this pull request May 15, 2025
This PR adds tensor parallel (TP) support for SimpleFSDP, together with
[this PR](pytorch/pytorch#152286) from PyTorch.

**(This is a placeholder for now, and shall be merged only after the
PyTorch PR has landed.)**

**Profile Trace & Convergence on debug model:**

<img width="1578" alt="Screenshot 2025-04-27 at 3 48 35 PM"
src="https://github.com/user-attachments/assets/42d36eb7-51c1-4f60-aa2e-20b58d8c491b"
/>


<img width="1369" alt="Screenshot 2025-04-27 at 3 31 46 PM"
src="https://github.com/user-attachments/assets/79e25f88-9e8a-4fe2-b7c9-a7e9747d19a4"
/>


The loss curves are all benchmarked with `training.seed = 42` on 4 GPUs.

The end-to-end SimpleFSDP + TP integration has been shown to work in the
PR from this fork: tianyu-l/pytorch_intern24#25.
Per the discussion with Tianyu, the new PyTorch PR adds
`_StridedShard` following FSDP2 to be compatible with distributed
checkpointing. The new benchmark results show that it works properly.
wwwjn pushed a commit to pytorch/torchtitan that referenced this pull request May 16, 2025
Labels
module: dynamo · oncall: distributed · open source · triaged