Pipeline Parallelism Fails when stage input does not produce gradients in all stages. #152827

Open
man2machine opened this issue May 5, 2025 · 4 comments
Labels
module: pipelining (Pipeline Parallelism) · oncall: distributed (Add this issue/PR to distributed oncall triage queue) · pipeline parallelism (Issues related to https://pytorch.org/docs/master/pipeline.html) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

man2machine commented May 5, 2025

🐛 Describe the bug

TL;DR: Pipeline parallelism fails if a stage input does not have gradients produced for it.

Consider the case where an output from each pipeline stage is passed to the next stage, but whether that output is actually used for a particular batch is conditional (it depends on the model's code). Hence, in many cases (such as conditional or mixture models), these inputs may not be used by a particular stage, resulting in an error from get_bwd_send_ops: "[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} and is expecting to send gradients to stage {grad_recv_stage}".

As a result, such a model fails when using pipeline parallelism, even though it runs fine with FSDP (or no parallelism). This happens because, for each input tensor to a stage, the backward pass expects a gradient for that tensor and raises the error above if none is produced, even if that tensor would otherwise be passed on to a subsequent stage where gradients would be produced.
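For illustration, here is a minimal single-process sketch (module and tensor names are just examples, not from any real model) of the autograd behavior that triggers this: an input that is only conditionally used ends up with a None gradient.

```python
import torch
import torch.nn as nn

# Minimal illustration (not using the PP APIs): the `aux` input is only used on
# some microbatches, so autograd produces no gradient for it.
class ConditionalStage(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.expert_a = nn.Linear(dim, dim)
        self.expert_b = nn.Linear(dim, dim)

    def forward(self, x, aux, use_aux):
        out = self.expert_a(x)
        if use_aux:  # data/config-dependent branch, as in conditional/mixture models
            out = out + self.expert_b(aux)
        return out

stage = ConditionalStage(8)
x = torch.randn(2, 8, requires_grad=True)
aux = torch.randn(2, 8, requires_grad=True)

stage(x, aux, use_aux=False).sum().backward()
print(x.grad is None)    # False
print(aux.grad is None)  # True: no gradient flows to `aux`; under PP this is the
                         # None that get_bwd_send_ops tries (and fails) to send back
```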

Currently the implementation uses dist.isend to send the gradient tensor, but in order to send None, a different asynchronous P2P communication operation is needed, one that can asynchronously send or recv objects which may or may not be tensors.

It would be great if this could be implemented, as pipeline parallelism is critical to achieving high throughput in distributed execution, and conditional or mixture models are currently limited by this bug.

Versions

PyTorch 2.6+

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k

zou3519 added the oncall: distributed label May 6, 2025
man2machine (Author) commented May 12, 2025

Update on this issue:

I was partially able to work around this by adding an additional graph_connection output to the pipeline stage outputs (graph_connection = graph_connection + (0.0 * param.flatten()[0])), tracking all unused stage outputs and model parameters via backward hooks, and eventually adding this graph connection output to the loss in the final stage.
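Roughly, the graph connection helper looks like this (an illustrative sketch of my workaround with placeholder names, not library code):

```python
import torch

def make_graph_connection(unused_tensors):
    # Multiply each unused input/parameter by 0.0 so it joins the autograd graph
    # without changing any values; the result is added to the loss in the last stage.
    conn = None
    for t in unused_tensors:
        term = 0.0 * t.reshape(-1)[0]
        conn = term if conn is None else conn + term
    return conn

# In the final stage (placeholder names):
# loss = criterion(output, target) + make_graph_connection(unused_tensors)
```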

I have tested this in various scenarios, and the workaround confirms that the disconnected gradients (producing None in the backward pass) are indeed the issue in PP, PP + FSDP, and FSDP.

The workaround does allow the code to run with PP only, and with FSDP + PP when there is only one level of FSDP wrapping. However, it doesn't work with two levels of FSDP wrapping (as in TorchTitan, where the full model and the transformer blocks are each fully sharded), because aten.select.int has no registered sharding strategy. If I avoid the indexing, I instead get mixed DTensor and Tensor inputs for aten.add.Tensor. I believe this can be fixed by also adding the graph connection workaround inside the transformer submodules that fully_shard is applied to.

All of these workarounds introduce significant additional user code, and may add unneeded communication from the extra graph connections. A proper fix is needed, so the issue still stands.

fegin added the triaged label May 12, 2025
fegin (Contributor) commented May 12, 2025

@H-Huang can you take a look? Thanks!

H-Huang (Member) commented May 12, 2025

Do you have a minimal example of this bug that we could use to aid the feature development (and eventually turn into a test for CI)? That would be really helpful!

It sounds like the fix is to set gradients to 0 even if the input grads are None; that sounds doable to update in our PP framework. I'm assuming this use case only comes up for stages handling the grads for multiple inputs?

An alternative is to skip sending gradients when no grads are detected, but that would also require the previous stage to know whether to post the recv, which seems a bit more challenging, so I think the former is better.
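Roughly, the former would mean something like the following in the backward send path (just a sketch with placeholder names, not the actual framework code):

```python
import torch

def zero_fill_input_grads(stage_inputs, input_grads):
    # Sketch: substitute a zero tensor whenever autograd produced no grad for a
    # stage input, so the backward send always has a well-defined tensor to transmit.
    return [
        g if g is not None else torch.zeros_like(inp)
        for inp, g in zip(stage_inputs, input_grads)
    ]
```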

H-Huang added the pipeline parallelism and module: pipelining labels May 12, 2025
man2machine (Author) commented May 13, 2025

I can work on a minimal example of the bug in a few days.

This bug happens in:

  • PP, when a stage input tensor does not have gradients
  • FSDP, when a parameter tensor does not have gradients
  • PP + FSDP, when a stage input or parameter tensor does not have gradients

To work around this, I manually added the graph connection tensor I described earlier to force PyTorch to connect those tensors into the graph, and I added hooks to each input tensor and parameter tensor to check whether it was used. It ends up being very hairy when you use nested fully_shard like in TorchTitan, and I still haven't finished the workaround for that case.

Setting the gradients to 0 sounds like the right approach, but I suggest being careful to avoid unnecessary communication overhead. Specifically, sending a full tensor of zeros could waste megabytes or gigabytes of bandwidth depending on the tensor's size, so there needs to be a way of signaling a zero update, such as a special sentinel/marker or some other construct. We would also want to prevent additional computation from happening based on those zero gradients; otherwise much of the benefit of mixture-of-experts models is lost.
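For example, something along these lines could avoid shipping full zero tensors (purely illustrative, using blocking send/recv for simplicity; the real framework batches isend/irecv, so the actual protocol would need to fit that):

```python
import torch
import torch.distributed as dist

def send_grad_or_sentinel(grad, dst):
    # Send a 1-byte flag first; only send the full gradient tensor if one exists.
    has_grad = torch.tensor([0 if grad is None else 1], dtype=torch.uint8)
    dist.send(has_grad, dst)
    if grad is not None:
        dist.send(grad, dst)

def recv_grad_or_sentinel(shape, dtype, src):
    has_grad = torch.empty(1, dtype=torch.uint8)
    dist.recv(has_grad, src)
    if has_grad.item() == 0:
        return None  # treat as a zero update without allocating or transferring zeros
    grad = torch.empty(shape, dtype=dtype)
    dist.recv(grad, src)
    return grad
```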

Skipping gradient sends would be difficult, as the receiving stage would not necessarily know beforehand which inputs/parameters are going to receive gradients until the forward pass of the sending stage has completed.
